Skip to content

Commit ddcc470

Browse files
committed
refactor: simplify StorageManager naming
Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent b266d39 commit ddcc470

31 files changed

Lines changed: 188 additions & 261 deletions

README.md

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ TransferQueue offers **fine-grained, sub-sample-level** data management and **lo
5050

5151
### Control Plane: Panoramic Data Management
5252

53-
In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. Once all required data fields are ready (i.e., written to the `TransferQueueStorageManager`), the data sample can be consumed by downstream tasks.
53+
In the control plane, `TransferQueueController` tracks the **production status** and **consumption status** of each training sample as metadata. Once all required data fields are ready (i.e., written to the `StorageManager`), the data sample can be consumed by downstream tasks.
5454

5555
We also track the consumption history for each computational task (e.g., `generate_sequences`, `compute_log_prob`, etc.). Therefore, even when different computational tasks require the same data field, they can consume the data independently without interfering with each other.
5656

@@ -66,7 +66,7 @@ To make the data retrieval process more customizable, we provide a `Sampler` cla
6666

6767
In the data plane, we utilize a pluggable design, enabling TransferQueue to integrate with different storage backends based on user requirements.
6868

69-
Specifically, we provide a `TransferQueueStorageManager` abstraction class that defines the core APIs as follows:
69+
Specifically, we provide a `StorageManager` abstraction class that defines the core APIs as follows:
7070

7171
- `async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None`
7272
- `async def get_data(self, metadata: BatchMeta) -> TensorDict`
@@ -298,21 +298,19 @@ The data plane is organized as follows:
298298
│ │── simple_storage.py # Default distributed storage backend (SimpleStorageUnit) by TQ
299299
│ ├── managers/ # Managers are upper level interfaces that encapsulate the interaction logic with TQ system.
300300
│ │ ├── __init__.py
301-
│ │ ├──base.py # TransferQueueStorageManager, KVStorageManager
302-
│ │ ├──simple_backend_manager.py # AsyncSimpleStorageManager
301+
│ │ ├──base.py # StorageManager, KVStorageManager, StorageManagerFactory
302+
│ │ ├──simple_storage_manager.py # AsyncSimpleStorageManager
303303
│ │ ├──yuanrong_manager.py # YuanrongStorageManager
304-
│ │ ├──mooncake_manager.py # MooncakeStorageManager
305-
│ │ └──factory.py # TransferQueueStorageManagerFactory
304+
│ │ └──mooncake_manager.py # MooncakeStorageManager
306305
│ └── clients/ # Clients are lower level interfaces that directly manipulate the target storage backend.
307306
│ │ ├── __init__.py
308-
│ │ ├── base.py # TransferQueueStorageKVClient
307+
│ │ ├── base.py # StorageKVClient, StorageKVClientFactory
309308
│ │ ├── yuanrong_client.py # YuanrongStorageClient
310309
│ │ ├── mooncake_client.py # MooncakeStorageClient
311-
│ │ ├── ray_storage_client.py # RayStorageClient
312-
│ │ └── factory.py # TransferQueueStorageClientFactory
310+
│ │ └── ray_storage_client.py # RayStorageClient
313311
```
314312

315-
To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `TransferQueueStorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends.
313+
To integrate TransferQueue with a custom storage backend, start by implementing a subclass that inherits from `StorageManager`. This subclass acts as an adapter between the TransferQueue system and the target storage backend. For KV-based storage backends, you can simply inherit from `KVStorageManager`, which can serve as the general manager for all KV-based backends.
316314

317315
Distributed storage backends often come with their own native clients that serve as the interface to the storage system. In such cases, you can write a low-level adapter for the native client, following the examples provided in the `storage/clients` directory.
318316

scripts/put_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
from transfer_queue import TransferQueueClient
3232
from transfer_queue.controller import TransferQueueController
33-
from transfer_queue.storage.simple_backend import SimpleStorageUnit
33+
from transfer_queue.storage.simple_storage import SimpleStorageUnit
3434
from transfer_queue.utils.common import get_placement_group
3535
from transfer_queue.utils.zmq_utils import process_zmq_server_info
3636

tests/test_async_simple_storage_manager.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from transfer_queue.metadata import BatchMeta
2626
from transfer_queue.storage import AsyncSimpleStorageManager
27-
from transfer_queue.utils.enum_utils import TransferQueueRole
27+
from transfer_queue.utils.enum_utils import Role
2828
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo
2929

3030

@@ -35,13 +35,13 @@ async def mock_async_storage_manager():
3535
# Mock storage unit infos
3636
storage_unit_infos = {
3737
"storage_0": ZMQServerInfo(
38-
role=TransferQueueRole.STORAGE,
38+
role=Role.STORAGE,
3939
id="storage_0",
4040
ip="127.0.0.1",
4141
ports={"put_get_socket": 12345},
4242
),
4343
"storage_1": ZMQServerInfo(
44-
role=TransferQueueRole.STORAGE,
44+
role=Role.STORAGE,
4545
id="storage_1",
4646
ip="127.0.0.1",
4747
ports={"put_get_socket": 12346},
@@ -50,7 +50,7 @@ async def mock_async_storage_manager():
5050

5151
# Mock controller info
5252
controller_info = ZMQServerInfo(
53-
role=TransferQueueRole.CONTROLLER,
53+
role=Role.CONTROLLER,
5454
id="controller_0",
5555
ip="127.0.0.1",
5656
ports={"handshake_socket": 12347, "data_status_update_socket": 12348},
@@ -61,9 +61,7 @@ async def mock_async_storage_manager():
6161
}
6262

6363
# Mock the handshake process entirely to avoid ZMQ complexity
64-
with patch(
65-
"transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"
66-
) as mock_connect:
64+
with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller") as mock_connect:
6765
# Mock the manager without actually connecting
6866
manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager)
6967
manager.storage_manager_id = "test_storage_manager"
@@ -148,7 +146,7 @@ async def test_async_storage_manager_error_handling():
148146
# Mock storage unit infos
149147
storage_unit_infos = {
150148
"storage_0": ZMQServerInfo(
151-
role=TransferQueueRole.STORAGE,
149+
role=Role.STORAGE,
152150
id="storage_0",
153151
ip="127.0.0.1",
154152
ports={"put_get_socket": 12345},
@@ -157,7 +155,7 @@ async def test_async_storage_manager_error_handling():
157155

158156
# Mock controller info
159157
controller_info = ZMQServerInfo(
160-
role=TransferQueueRole.CONTROLLER,
158+
role=Role.CONTROLLER,
161159
id="controller_0",
162160
ip="127.0.0.1",
163161
ports={"handshake_socket": 12346, "data_status_update_socket": 12347},
@@ -242,19 +240,19 @@ async def test_get_data_routes_from_hash():
242240
"""get_data should route using global_idx % num_su (hash routing)."""
243241
storage_unit_infos = {
244242
"storage_0": ZMQServerInfo(
245-
role=TransferQueueRole.STORAGE,
243+
role=Role.STORAGE,
246244
id="storage_0",
247245
ip="127.0.0.1",
248246
ports={"put_get_socket": 19010},
249247
),
250248
"storage_1": ZMQServerInfo(
251-
role=TransferQueueRole.STORAGE,
249+
role=Role.STORAGE,
252250
id="storage_1",
253251
ip="127.0.0.1",
254252
ports={"put_get_socket": 19011},
255253
),
256254
}
257-
with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"):
255+
with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"):
258256
manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager)
259257
manager.storage_manager_id = "test_get"
260258
manager.storage_unit_infos = storage_unit_infos
@@ -295,19 +293,19 @@ async def test_clear_data_routes_from_hash():
295293
"""clear_data should route using global_idx % num_su (hash routing)."""
296294
storage_unit_infos = {
297295
"storage_0": ZMQServerInfo(
298-
role=TransferQueueRole.STORAGE,
296+
role=Role.STORAGE,
299297
id="storage_0",
300298
ip="127.0.0.1",
301299
ports={"put_get_socket": 19020},
302300
),
303301
"storage_1": ZMQServerInfo(
304-
role=TransferQueueRole.STORAGE,
302+
role=Role.STORAGE,
305303
id="storage_1",
306304
ip="127.0.0.1",
307305
ports={"put_get_socket": 19021},
308306
),
309307
}
310-
with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"):
308+
with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"):
311309
manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager)
312310
manager.storage_manager_id = "test_clear"
313311
manager.storage_unit_infos = storage_unit_infos
@@ -346,19 +344,19 @@ async def test_hash_routing_stable_across_batch_sizes():
346344
"""
347345
storage_unit_infos = {
348346
"storage_0": ZMQServerInfo(
349-
role=TransferQueueRole.STORAGE,
347+
role=Role.STORAGE,
350348
id="storage_0",
351349
ip="127.0.0.1",
352350
ports={"put_get_socket": 19030},
353351
),
354352
"storage_1": ZMQServerInfo(
355-
role=TransferQueueRole.STORAGE,
353+
role=Role.STORAGE,
356354
id="storage_1",
357355
ip="127.0.0.1",
358356
ports={"put_get_socket": 19031},
359357
),
360358
}
361-
with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"):
359+
with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"):
362360
manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager)
363361
manager.storage_manager_id = "test_hash_batch"
364362
manager.storage_unit_infos = storage_unit_infos
@@ -407,19 +405,19 @@ async def test_hash_routing_stable_reversed_order():
407405
"""
408406
storage_unit_infos = {
409407
"storage_0": ZMQServerInfo(
410-
role=TransferQueueRole.STORAGE,
408+
role=Role.STORAGE,
411409
id="storage_0",
412410
ip="127.0.0.1",
413411
ports={"put_get_socket": 19040},
414412
),
415413
"storage_1": ZMQServerInfo(
416-
role=TransferQueueRole.STORAGE,
414+
role=Role.STORAGE,
417415
id="storage_1",
418416
ip="127.0.0.1",
419417
ports={"put_get_socket": 19041},
420418
),
421419
}
422-
with patch("transfer_queue.storage.managers.base.TransferQueueStorageManager._connect_to_controller"):
420+
with patch("transfer_queue.storage.managers.base.StorageManager._connect_to_controller"):
423421
manager = AsyncSimpleStorageManager.__new__(AsyncSimpleStorageManager)
424422
manager.storage_manager_id = "test_hash_order"
425423
manager.storage_unit_infos = storage_unit_infos

tests/test_client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from transfer_queue import TransferQueueClient
2626
from transfer_queue.metadata import BatchMeta
27-
from transfer_queue.utils.enum_utils import TransferQueueRole
27+
from transfer_queue.utils.enum_utils import Role
2828
from transfer_queue.utils.zmq_utils import (
2929
ZMQMessage,
3030
ZMQRequestType,
@@ -59,7 +59,7 @@ def __init__(self, controller_id="controller_0"):
5959
self.request_port = self._bind_to_random_port(self.request_socket)
6060

6161
self.zmq_server_info = ZMQServerInfo(
62-
role=TransferQueueRole.CONTROLLER,
62+
role=Role.CONTROLLER,
6363
id=controller_id,
6464
ip="127.0.0.1",
6565
ports={
@@ -300,7 +300,7 @@ def __init__(self, storage_id="storage_0"):
300300
self.data_port = self._bind_to_random_port(self.data_socket)
301301

302302
self.zmq_server_info = ZMQServerInfo(
303-
role=TransferQueueRole.STORAGE,
303+
role=Role.STORAGE,
304304
id=storage_id,
305305
ip="127.0.0.1",
306306
ports={
@@ -409,7 +409,7 @@ def client_setup(mock_controller, mock_storage):
409409

410410
# Mock the storage manager to avoid handshake issues but mock all data operations
411411
with patch(
412-
"transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller"
412+
"transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller"
413413
):
414414
config = {
415415
"controller_info": mock_controller.zmq_server_info,
@@ -502,7 +502,7 @@ def test_single_controller_multiple_storages():
502502

503503
# Mock the storage manager to avoid handshake issues but mock all data operations
504504
with patch(
505-
"transfer_queue.storage.managers.simple_backend_manager.AsyncSimpleStorageManager._connect_to_controller"
505+
"transfer_queue.storage.managers.simple_storage_manager.AsyncSimpleStorageManager._connect_to_controller"
506506
):
507507
config = {
508508
"controller_info": controller.zmq_server_info,

tests/test_kv_storage_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def test_generate_values(test_data):
106106
assert torch.equal(values[i], expected_values[i])
107107

108108

109-
@patch("transfer_queue.storage.managers.base.StorageClientFactory.create")
109+
@patch("transfer_queue.storage.managers.base.StorageKVClientFactory.create")
110110
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)
111111
def test_merge_tensors_to_tensordict(mock_create, test_data):
112112
"""Test whether _merge_kv_to_tensordict can correctly reconstruct the TensorDict."""
@@ -268,7 +268,7 @@ def test_data_for_put_data():
268268
}
269269

270270

271-
STORAGE_CLIENT_FACTORY_PATH = "transfer_queue.storage.managers.base.StorageClientFactory"
271+
STORAGE_CLIENT_FACTORY_PATH = "transfer_queue.storage.managers.base.StorageKVClientFactory"
272272

273273

274274
@patch.object(KVStorageManager, "_connect_to_controller", lambda self: None)

tests/test_metadata.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ class TestStorageUnitDataStrict:
678678

679679
def test_put_data_length_mismatch_raises(self):
680680
"""put_data must raise when global_indexes and field values have different lengths."""
681-
from transfer_queue.storage.simple_backend import StorageUnitData
681+
from transfer_queue.storage.simple_storage import StorageUnitData
682682

683683
sud = StorageUnitData(storage_size=10)
684684
# 3 indexes but only 2 values — must raise, not silently drop

tests/test_ray_p2p.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323

2424
from transfer_queue.client import TransferQueueClient
2525
from transfer_queue.metadata import BatchMeta
26-
from transfer_queue.storage.managers.base import KVStorageManager
27-
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
26+
from transfer_queue.storage.managers.base import KVStorageManager, StorageManagerFactory
2827
from transfer_queue.utils.zmq_utils import ZMQServerInfo
2928

3029
TEST_CONFIGS: list[tuple[tuple[int, int], torch.dtype]] = [
@@ -45,18 +44,18 @@
4544

4645
# Step 1: Mock Controller Role
4746
try:
48-
from transfer_queue.role import TransferQueueRole
47+
from transfer_queue.role import Role
4948
except ImportError:
5049
from enum import Enum
5150

52-
class TransferQueueRole(Enum):
51+
class Role(Enum):
5352
CONTROLLER = "controller"
5453
STORAGE = "storage"
5554

5655

5756
def create_mock_controller():
5857
return ZMQServerInfo(
59-
role=TransferQueueRole.CONTROLLER,
58+
role=Role.CONTROLLER,
6059
id="controller_0",
6160
ip="127.0.0.1",
6261
ports={
@@ -71,9 +70,9 @@ def create_mock_controller():
7170
def ensure_mock_storage_manager_registered():
7271
"""Ensure MockKVStorageManager is registered in current process."""
7372

74-
if "KV_MOCK" not in TransferQueueStorageManagerFactory._registry:
73+
if "KV_MOCK" not in StorageManagerFactory._registry:
7574

76-
@TransferQueueStorageManagerFactory.register("KV_MOCK")
75+
@StorageManagerFactory.register("KV_MOCK")
7776
class MockKVStorageManager(KVStorageManager):
7877
def _connect_to_controller(self):
7978
pass

tests/test_simple_storage_unit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import torch
2222
import zmq
2323

24-
from transfer_queue.storage.simple_backend import SimpleStorageUnit
24+
from transfer_queue.storage.simple_storage import SimpleStorageUnit
2525
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType
2626

2727

@@ -420,7 +420,7 @@ def test_storage_unit_data_direct():
420420

421421
def test_storage_unit_data_capacity_uses_active_keys():
422422
"""Capacity check must use _active_keys, not scan field_data."""
423-
from transfer_queue.storage.simple_backend import StorageUnitData
423+
from transfer_queue.storage.simple_storage import StorageUnitData
424424

425425
storage = StorageUnitData(storage_size=3)
426426

tests/test_storage_client_factory.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import pytest
2020
import torch
2121

22-
from transfer_queue.storage.clients.factory import StorageClientFactory
22+
from transfer_queue.storage.clients.base import StorageKVClientFactory
2323
from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient
2424

2525

@@ -29,12 +29,12 @@ def setUp(self):
2929

3030
@pytest.mark.skipif(find_spec("datasystem") is None, reason="datasystem is not available")
3131
def test_create_client(self):
32-
self.assertIn("YuanrongStorageClient", StorageClientFactory._registry)
33-
self.assertIs(StorageClientFactory._registry["YuanrongStorageClient"], YuanrongStorageClient)
34-
StorageClientFactory.create("YuanrongStorageClient", self.cfg)
32+
self.assertIn("YuanrongStorageClient", StorageKVClientFactory._registry)
33+
self.assertIs(StorageKVClientFactory._registry["YuanrongStorageClient"], YuanrongStorageClient)
34+
StorageKVClientFactory.create("YuanrongStorageClient", self.cfg)
3535

3636
with self.assertRaises(ValueError) as cm:
37-
StorageClientFactory.create("abc", self.cfg)
37+
StorageKVClientFactory.create("abc", self.cfg)
3838
self.assertIn("Unknown StorageClient", str(cm.exception))
3939

4040
@pytest.mark.skipif(
@@ -47,7 +47,7 @@ def test_client_create_empty_tensorlist(self):
4747
for t in tensors:
4848
shapes.append(t.shape)
4949
dtypes.append(t.dtype)
50-
client = StorageClientFactory.create("YuanrongStorageClient", self.cfg)
50+
client = StorageKVClientFactory.create("YuanrongStorageClient", self.cfg)
5151

5252
empty_tensors = client._create_empty_npu_tensorlist(shapes, dtypes)
5353
self.assertEqual(len(tensors), len(empty_tensors))

transfer_queue/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
BatchMeta,
2929
)
3030
from transfer_queue.storage import (
31-
TransferQueueStorageManagerFactory,
31+
StorageManagerFactory,
3232
)
3333
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
3434
from transfer_queue.utils.logging_utils import get_logger
@@ -92,7 +92,7 @@ def initialize_storage_manager(
9292
- zmq_info: ZMQ server information about the storage units
9393
9494
"""
95-
self.storage_manager = TransferQueueStorageManagerFactory.create(
95+
self.storage_manager = StorageManagerFactory.create(
9696
manager_type, controller_info=self._controller, config=config
9797
)
9898

0 commit comments

Comments
 (0)