Skip to content

Commit dcf3dc3

Browse files
committed
provide decorator for TransferQueueStorageManagerFactory
Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 7f7dbed commit dcf3dc3

2 files changed

Lines changed: 13 additions & 12 deletions

File tree

transfer_queue/storage/managers/factory.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from typing import Any
1515

1616
from transfer_queue.storage.managers.base import TransferQueueStorageManager
17-
from transfer_queue.storage.managers.simple_backend_manager import AsyncSimpleStorageManager
1817

1918

2019
class TransferQueueStorageManagerFactory:
@@ -23,13 +22,17 @@ class TransferQueueStorageManagerFactory:
2322
_registry: dict[str, type[TransferQueueStorageManager]] = {}
2423

2524
@classmethod
26-
def register(cls, manager_type: str, manager_cls: type[TransferQueueStorageManager]):
27-
if not issubclass(manager_cls, TransferQueueStorageManager):
28-
raise TypeError(
29-
f"manager_cls {getattr(manager_cls, '__name__', repr(manager_cls))} must be "
30-
f"a subclass of TransferQueueStorageManager"
31-
)
32-
cls._registry[manager_type] = manager_cls
25+
def register(cls, manager_type: str):
26+
def decorator(manager_cls: type[TransferQueueStorageManager]):
27+
if not issubclass(manager_cls, TransferQueueStorageManager):
28+
raise TypeError(
29+
f"manager_cls {getattr(manager_cls, '__name__', repr(manager_cls))} must be "
30+
f"a subclass of TransferQueueStorageManager"
31+
)
32+
cls._registry[manager_type] = manager_cls
33+
return manager_cls
34+
35+
return decorator
3336

3437
@classmethod
3538
def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStorageManager:
@@ -38,7 +41,3 @@ def create(cls, manager_type: str, config: dict[str, Any]) -> TransferQueueStora
3841
f"Unknown manager_type: {manager_type}. Supported managers include: {list(cls._registry.keys())}"
3942
)
4043
return cls._registry[manager_type](config)
41-
42-
43-
# Register all the StorageManager
44-
TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager", AsyncSimpleStorageManager)

transfer_queue/storage/managers/simple_backend_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from transfer_queue.metadata import BatchMeta
2828
from transfer_queue.storage.managers.base import TransferQueueStorageManager
29+
from transfer_queue.storage.managers.factory import TransferQueueStorageManagerFactory
2930
from transfer_queue.storage.simple_backend import StorageMetaGroup
3031
from transfer_queue.utils.utils import limit_pytorch_auto_parallel_threads
3132
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket
@@ -34,6 +35,7 @@
3435
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
3536

3637

38+
@TransferQueueStorageManagerFactory.register("AsyncSimpleStorageManager")
3739
class AsyncSimpleStorageManager(TransferQueueStorageManager):
3840
"""Asynchronous storage manager that handles multiple storage units.
3941

0 commit comments

Comments
 (0)