Skip to content

Commit b9cb614

Browse files
committed
[client, storage] refactor: unify dynamic ZMQ socket decorator between simple_backend_manager and client (#66)
as title. --------- Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent aec9192 commit b9cb614

20 files changed

Lines changed: 209 additions & 287 deletions

recipe/simple_use_case/single_controller_demo.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,24 @@ def compute_loss(data1, _data2):
5353

5454
def compute_reward(response_ids: torch.Tensor) -> TensorDict:
5555
"""Simulate a reward model that scores each token position in the response.
56+
Returns a TensorDict with a ``"rm_score"`` field whose shape matches
57+
``response_ids`` (i.e. one scalar per response token).
58+
"""
59+
time.sleep(1)
60+
reward = torch.randn_like(response_ids, dtype=torch.float32)
61+
62+
return TensorDict({"rm_score": reward}, batch_size=response_ids.size(0))
63+
64+
65+
def compute_advantage(rewards: torch.Tensor) -> TensorDict:
66+
"""Simulate the process of computing the advantage.
5667
5768
Returns a TensorDict with an ``"advantage"`` field whose shape matches
58-
``response_ids`` (i.e. one scalar per response token).
69+
``rewards`` (i.e. one scalar per reward).
5970
"""
6071
time.sleep(1)
61-
advantage = torch.randn_like(response_ids, dtype=torch.float32)
62-
return TensorDict({"advantage": advantage}, batch_size=response_ids.size(0))
72+
advantage = torch.randn_like(rewards, dtype=torch.float32)
73+
return TensorDict({"advantage": advantage}, batch_size=rewards.size(0))
6374

6475

6576
class TrainingWorker:
@@ -89,7 +100,7 @@ def infer_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
89100
"""Simulate forward-only inference"""
90101
# 1. Pull data from storage
91102
data = tq.kv_batch_get_by_meta(meta=kv_meta)
92-
logger.info(f"compute_log_prob: got data {data}")
103+
logger.info(f"infer_batch: got data {data}")
93104

94105
# 2. Model forward
95106
output = compute_log_prob(data["prompt_ids"], data["response_ids"])
@@ -494,6 +505,13 @@ def fit(self):
494505
meta = tq.kv_batch_put(keys=meta.keys, partition_id=meta.partition_id, fields=reward_output)
495506
logger.info(f"demo reward KVBatchMeta: {meta}")
496507

508+
# ========================= Compute advantage =========================
509+
meta.fields = ["response_ids", "ref_log_prob", "old_log_prob","rm_score"]
510+
advantage_data = tq.kv_batch_get_by_meta(meta=meta)
511+
advantage_output = compute_advantage(advantage_data["rm_score"])
512+
meta = tq.kv_batch_put(keys=meta.keys, partition_id=meta.partition_id, fields=advantage_output)
513+
logger.info(f"demo advantage KVBatchMeta: {meta}")
514+
497515
# ========================= Update actor =========================
498516
meta.fields = [
499517
"input_ids",

transfer_queue/client.py

Lines changed: 22 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,9 @@
1414
# limitations under the License.
1515

1616
import asyncio
17-
import logging
1817
import os
1918
import threading
20-
from functools import wraps
2119
from typing import Any, Callable, Optional
22-
from uuid import uuid4
2320

2421
import torch
2522
import zmq
@@ -34,25 +31,25 @@
3431
TransferQueueStorageManagerFactory,
3532
)
3633
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
34+
from transfer_queue.utils.logging_utils import get_logger
3735
from transfer_queue.utils.zmq_utils import (
3836
ZMQMessage,
3937
ZMQRequestType,
4038
ZMQServerInfo,
41-
create_zmq_socket,
42-
format_zmq_address,
39+
with_zmq_socket,
4340
)
4441

45-
logger = logging.getLogger(__name__)
46-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
47-
48-
# Ensure logger has a handler
49-
if not logger.hasHandlers():
50-
handler = logging.StreamHandler()
51-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
52-
logger.addHandler(handler)
42+
logger = get_logger(__name__)
5343

5444
TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8))
5545

46+
# Pre-bound decorator for controller socket operations.
47+
with_controller_socket = with_zmq_socket(
48+
"request_handle_socket",
49+
get_identity=lambda self: self.client_id,
50+
get_peer=lambda self, target: self._controller,
51+
)
52+
5653

5754
class AsyncTransferQueueClient:
5855
"""Asynchronous client for interacting with TransferQueue controller and storage systems.
@@ -99,63 +96,8 @@ def initialize_storage_manager(
9996
manager_type, controller_info=self._controller, config=config
10097
)
10198

102-
# TODO (TQStorage): Provide a general dynamic socket function for both Client & Storage @huazhong.
103-
@staticmethod
104-
def dynamic_socket(socket_name: str):
105-
"""Decorator to auto-manage ZMQ sockets for Controller/Storage servers.
106-
107-
Handles socket lifecycle: create -> connect -> inject -> close.
108-
109-
Args:
110-
socket_name: Port name from server config to use for ZMQ connection (e.g., "data_req_port")
111-
112-
Decorated Function Requirements:
113-
1. Must be an async class method (needs `self`)
114-
2. `self` must have:
115-
- `_controller`: Server registry
116-
- `client_id`: Unique client ID for socket identity
117-
3. Receives ZMQ socket via `socket` keyword argument (injected by decorator)
118-
"""
119-
120-
def decorator(func: Callable):
121-
@wraps(func)
122-
async def wrapper(self, *args, **kwargs):
123-
server_info = self._controller
124-
if not server_info:
125-
raise RuntimeError("No controller registered")
126-
127-
context = zmq.asyncio.Context()
128-
address = format_zmq_address(server_info.ip, server_info.ports.get(socket_name))
129-
identity = f"{self.client_id}_to_{server_info.id}_{uuid4().hex[:8]}".encode()
130-
sock = create_zmq_socket(context, zmq.DEALER, identity=identity, ip=server_info.ip)
131-
132-
try:
133-
sock.connect(address)
134-
logger.debug(
135-
f"[{self.client_id}]: Connected to Controller {server_info.id} at {address} "
136-
f"with identity {identity.decode()}"
137-
)
138-
139-
kwargs["socket"] = sock
140-
return await func(self, *args, **kwargs)
141-
except Exception as e:
142-
logger.error(f"[{self.client_id}]: Error in socket operation with Controller {server_info.id}: {e}")
143-
raise
144-
finally:
145-
try:
146-
if not sock.closed:
147-
sock.close(linger=-1)
148-
except Exception as e:
149-
logger.warning(f"[{self.client_id}]: Error closing socket to Controller {server_info.id}: {e}")
150-
151-
context.term()
152-
153-
return wrapper
154-
155-
return decorator
156-
15799
# ==================== Basic API ====================
158-
@dynamic_socket(socket_name="request_handle_socket")
100+
@with_controller_socket
159101
async def async_get_meta(
160102
self,
161103
data_fields: list[str],
@@ -245,7 +187,7 @@ async def async_get_meta(
245187
f"{response_msg.body.get('message', 'Unknown error')}"
246188
)
247189

248-
@dynamic_socket(socket_name="request_handle_socket")
190+
@with_controller_socket
249191
async def async_set_custom_meta(
250192
self,
251193
metadata: BatchMeta,
@@ -545,7 +487,7 @@ async def async_clear_samples(self, metadata: BatchMeta):
545487
except Exception as e:
546488
raise RuntimeError(f"Error in clear_samples operation: {str(e)}") from e
547489

548-
@dynamic_socket(socket_name="request_handle_socket")
490+
@with_controller_socket
549491
async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
550492
"""Clear metadata in the controller.
551493
@@ -571,7 +513,7 @@ async def _clear_meta_in_controller(self, metadata: BatchMeta, socket=None):
571513
if response_msg.request_type != ZMQRequestType.CLEAR_META_RESPONSE:
572514
raise RuntimeError("Failed to clear samples metadata in controller.")
573515

574-
@dynamic_socket(socket_name="request_handle_socket")
516+
@with_controller_socket
575517
async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta:
576518
"""Get metadata required for the whole partition from controller.
577519
@@ -601,7 +543,7 @@ async def _get_partition_meta(self, partition_id: str, socket=None) -> BatchMeta
601543

602544
return response_msg.body["metadata"]
603545

604-
@dynamic_socket(socket_name="request_handle_socket")
546+
@with_controller_socket
605547
async def _clear_partition_in_controller(self, partition_id, socket=None):
606548
"""Clear the whole partition in the controller.
607549
@@ -628,7 +570,7 @@ async def _clear_partition_in_controller(self, partition_id, socket=None):
628570
raise RuntimeError(f"Failed to clear partition {partition_id} in controller.")
629571

630572
# ==================== Status Query API ====================
631-
@dynamic_socket(socket_name="request_handle_socket")
573+
@with_controller_socket
632574
async def async_get_consumption_status(
633575
self,
634576
task_name: str,
@@ -691,7 +633,7 @@ async def async_get_consumption_status(
691633
except Exception as e:
692634
raise RuntimeError(f"[{self.client_id}]: Error in get_consumption_status: {str(e)}") from e
693635

694-
@dynamic_socket(socket_name="request_handle_socket")
636+
@with_controller_socket
695637
async def async_get_production_status(
696638
self,
697639
data_fields: list[str],
@@ -823,7 +765,7 @@ async def async_check_production_status(
823765
return False
824766
return torch.all(production_status == 1).item()
825767

826-
@dynamic_socket(socket_name="request_handle_socket")
768+
@with_controller_socket
827769
async def async_reset_consumption(
828770
self,
829771
partition_id: str,
@@ -885,7 +827,7 @@ async def async_reset_consumption(
885827
except Exception as e:
886828
raise RuntimeError(f"[{self.client_id}]: Error in reset_consumption: {str(e)}") from e
887829

888-
@dynamic_socket(socket_name="request_handle_socket")
830+
@with_controller_socket
889831
async def async_get_partition_list(
890832
self,
891833
socket: Optional[zmq.asyncio.Socket] = None,
@@ -931,7 +873,7 @@ async def async_get_partition_list(
931873
raise RuntimeError(f"[{self.client_id}]: Error in get_partition_list: {str(e)}") from e
932874

933875
# ==================== KV Interface API ====================
934-
@dynamic_socket(socket_name="request_handle_socket")
876+
@with_controller_socket
935877
async def async_kv_retrieve_meta(
936878
self,
937879
keys: list[str] | str,
@@ -997,7 +939,7 @@ async def async_kv_retrieve_meta(
997939
except Exception as e:
998940
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_keys: {str(e)}") from e
999941

1000-
@dynamic_socket(socket_name="request_handle_socket")
942+
@with_controller_socket
1001943
async def async_kv_retrieve_keys(
1002944
self,
1003945
global_indexes: list[int] | int,
@@ -1060,7 +1002,7 @@ async def async_kv_retrieve_keys(
10601002
except Exception as e:
10611003
raise RuntimeError(f"[{self.client_id}]: Error in kv_retrieve_indexes: {str(e)}") from e
10621004

1063-
@dynamic_socket(socket_name="request_handle_socket")
1005+
@with_controller_socket
10641006
async def async_kv_list(
10651007
self,
10661008
partition_id: Optional[str] = None,

transfer_queue/controller.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import copy
17-
import logging
1817
import os
1918
import time
2019
from collections import defaultdict
@@ -37,6 +36,7 @@
3736
)
3837
from transfer_queue.sampler import BaseSampler, SequentialSampler
3938
from transfer_queue.utils.enum_utils import TransferQueueRole
39+
from transfer_queue.utils.logging_utils import get_logger
4040
from transfer_queue.utils.perf_utils import IntervalPerfMonitor
4141
from transfer_queue.utils.zmq_utils import (
4242
ZMQMessage,
@@ -48,14 +48,7 @@
4848
get_node_ip_address_raw,
4949
)
5050

51-
logger = logging.getLogger(__name__)
52-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
53-
54-
# Ensure logger has a handler (for Ray Actor subprocess)
55-
if not logger.hasHandlers():
56-
handler = logging.StreamHandler()
57-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
58-
logger.addHandler(handler)
51+
logger = get_logger(__name__)
5952

6053
TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 1))
6154
TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 5))

transfer_queue/dataloader/streaming_dataloader.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
1716
import os
1817
from typing import Optional
1918

@@ -22,15 +21,9 @@
2221

2322
from transfer_queue.dataloader.streaming_dataset import StreamingDataset
2423
from transfer_queue.metadata import BatchMeta
24+
from transfer_queue.utils.logging_utils import get_logger
2525

26-
logger = logging.getLogger(__name__)
27-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
28-
29-
# Ensure logger has a handler
30-
if not logger.hasHandlers():
31-
handler = logging.StreamHandler()
32-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
33-
logger.addHandler(handler)
26+
logger = get_logger(__name__)
3427

3528

3629
def _identity_collate_fn(data: tuple[TensorDict, BatchMeta]) -> tuple[TensorDict, BatchMeta]:

transfer_queue/dataloader/streaming_dataset.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
1716
import os
1817
import time
1918
import uuid
@@ -25,19 +24,13 @@
2524

2625
from transfer_queue.client import TransferQueueClient
2726
from transfer_queue.metadata import BatchMeta
27+
from transfer_queue.utils.logging_utils import get_logger
2828

2929
TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
3030
os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
3131
) # in seconds
3232

33-
logger = logging.getLogger(__name__)
34-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
35-
36-
# Ensure logger has a handler
37-
if not logger.hasHandlers():
38-
handler = logging.StreamHandler()
39-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
40-
logger.addHandler(handler)
33+
logger = get_logger(__name__)
4134

4235

4336
class StreamingDataset(IterableDataset):

transfer_queue/interface.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,11 @@
3939
cleanup_yuanrong_resources,
4040
initialize_yuanrong_backend,
4141
)
42+
from transfer_queue.utils.logging_utils import get_logger
4243
from transfer_queue.utils.zmq_utils import process_zmq_server_info
4344

44-
logger = logging.getLogger(__name__)
45-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
45+
46+
logger = get_logger(__name__)
4647

4748
_TRANSFER_QUEUE_CLIENT: Any = None
4849
_TRANSFER_QUEUE_STORAGE: Any = None

transfer_queue/metadata.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import copy
1717
import dataclasses
1818
import itertools
19-
import logging
2019
import os
2120
from collections import defaultdict
2221
from dataclasses import dataclass
@@ -27,14 +26,9 @@
2726
import torch
2827
from tensordict import TensorDict
2928

30-
logger = logging.getLogger(__name__)
31-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
29+
from transfer_queue.utils.logging_utils import get_logger
3230

33-
# Ensure logger has a handler
34-
if not logger.hasHandlers():
35-
handler = logging.StreamHandler()
36-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
37-
logger.addHandler(handler)
31+
logger = get_logger(__name__)
3832

3933

4034
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)