Skip to content

Commit d56e316

Browse files
committed
extract get_logger utility to reduce code duplication
Signed-off-by: ji-huazhong <hzji210@gmail.com>
1 parent 961c5df commit d56e316

20 files changed

Lines changed: 95 additions & 142 deletions

recipe/simple_use_case/single_controller_demo.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,24 @@ def compute_loss(data1, _data2):
5353

5454
def compute_reward(response_ids: torch.Tensor) -> TensorDict:
5555
"""Simulate a reward model that scores each token position in the response.
56+
Returns a TensorDict with a ``"rm_score"`` field whose shape matches
57+
``response_ids`` (i.e. one scalar per response token).
58+
"""
59+
time.sleep(1)
60+
reward = torch.randn_like(response_ids, dtype=torch.float32)
61+
62+
return TensorDict({"rm_score": reward}, batch_size=response_ids.size(0))
63+
64+
65+
def compute_advantage(rewards: torch.Tensor) -> TensorDict:
66+
"""Simulate the advantage computation.
5667
5768
Returns a TensorDict with an ``"advantage"`` field whose shape matches
58-
``response_ids`` (i.e. one scalar per response token).
69+
``rewards`` (i.e. one advantage value per reward element).
5970
"""
6071
time.sleep(1)
61-
advantage = torch.randn_like(response_ids, dtype=torch.float32)
62-
return TensorDict({"advantage": advantage}, batch_size=response_ids.size(0))
72+
advantage = torch.randn_like(rewards, dtype=torch.float32)
73+
return TensorDict({"advantage": advantage}, batch_size=rewards.size(0))
6374

6475

6576
class TrainingWorker:
@@ -89,7 +100,7 @@ def infer_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
89100
"""Simulate forward-only inference"""
90101
# 1. Pull data from storage
91102
data = tq.kv_batch_get_by_meta(meta=kv_meta)
92-
logger.info(f"compute_log_prob: got data {data}")
103+
logger.info(f"infer_batch: got data {data}")
93104

94105
# 2. Model forward
95106
output = compute_log_prob(data["prompt_ids"], data["response_ids"])
@@ -494,6 +505,13 @@ def fit(self):
494505
meta = tq.kv_batch_put(keys=meta.keys, partition_id=meta.partition_id, fields=reward_output)
495506
logger.info(f"demo reward KVBatchMeta: {meta}")
496507

508+
# ========================= Compute advantage =========================
509+
meta.fields = ["response_ids", "ref_log_prob", "old_log_prob", "rm_score"]
510+
advantage_data = tq.kv_batch_get_by_meta(meta=meta)
511+
advantage_output = compute_advantage(advantage_data["rm_score"])
512+
meta = tq.kv_batch_put(keys=meta.keys, partition_id=meta.partition_id, fields=advantage_output)
513+
logger.info(f"demo advantage KVBatchMeta: {meta}")
514+
497515
# ========================= Update actor =========================
498516
meta.fields = [
499517
"input_ids",

transfer_queue/client.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import asyncio
17-
import logging
1817
import os
1918
import threading
2019
from typing import Any, Callable, Optional
@@ -32,21 +31,15 @@
3231
TransferQueueStorageManagerFactory,
3332
)
3433
from transfer_queue.utils.common import limit_pytorch_auto_parallel_threads
34+
from transfer_queue.utils.logging_utils import get_logger
3535
from transfer_queue.utils.zmq_utils import (
3636
ZMQMessage,
3737
ZMQRequestType,
3838
ZMQServerInfo,
3939
with_zmq_socket,
4040
)
4141

42-
logger = logging.getLogger(__name__)
43-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
44-
45-
# Ensure logger has a handler
46-
if not logger.hasHandlers():
47-
handler = logging.StreamHandler()
48-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
49-
logger.addHandler(handler)
42+
logger = get_logger(__name__)
5043

5144
TQ_NUM_THREADS = int(os.environ.get("TQ_NUM_THREADS", 8))
5245

transfer_queue/controller.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# limitations under the License.
1515

1616
import copy
17-
import logging
1817
import os
1918
import time
2019
from collections import defaultdict
@@ -37,6 +36,7 @@
3736
)
3837
from transfer_queue.sampler import BaseSampler, SequentialSampler
3938
from transfer_queue.utils.enum_utils import TransferQueueRole
39+
from transfer_queue.utils.logging_utils import get_logger
4040
from transfer_queue.utils.perf_utils import IntervalPerfMonitor
4141
from transfer_queue.utils.zmq_utils import (
4242
ZMQMessage,
@@ -48,14 +48,7 @@
4848
get_node_ip_address_raw,
4949
)
5050

51-
logger = logging.getLogger(__name__)
52-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
53-
54-
# Ensure logger has a handler (for Ray Actor subprocess)
55-
if not logger.hasHandlers():
56-
handler = logging.StreamHandler()
57-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
58-
logger.addHandler(handler)
51+
logger = get_logger(__name__)
5952

6053
TQ_CONTROLLER_GET_METADATA_TIMEOUT = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_TIMEOUT", 1))
6154
TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL = int(os.environ.get("TQ_CONTROLLER_GET_METADATA_CHECK_INTERVAL", 5))

transfer_queue/dataloader/streaming_dataloader.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,16 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
17-
import os
1816
from typing import Optional
1917

2018
import torch
2119
from tensordict import TensorDict
2220

2321
from transfer_queue.dataloader.streaming_dataset import StreamingDataset
2422
from transfer_queue.metadata import BatchMeta
23+
from transfer_queue.utils.logging_utils import get_logger
2524

26-
logger = logging.getLogger(__name__)
27-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
28-
29-
# Ensure logger has a handler
30-
if not logger.hasHandlers():
31-
handler = logging.StreamHandler()
32-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
33-
logger.addHandler(handler)
25+
logger = get_logger(__name__)
3426

3527

3628
def _identity_collate_fn(data: tuple[TensorDict, BatchMeta]) -> tuple[TensorDict, BatchMeta]:

transfer_queue/dataloader/streaming_dataset.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
1716
import os
1817
import time
1918
import uuid
@@ -25,19 +24,13 @@
2524

2625
from transfer_queue.client import TransferQueueClient
2726
from transfer_queue.metadata import BatchMeta
27+
from transfer_queue.utils.logging_utils import get_logger
2828

2929
TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL = float(
3030
os.environ.get("TQ_STREAMING_DATASET_EMPTY_BATCH_SLEEP_INTERVAL", 1)
3131
) # in seconds
3232

33-
logger = logging.getLogger(__name__)
34-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
35-
36-
# Ensure logger has a handler
37-
if not logger.hasHandlers():
38-
handler = logging.StreamHandler()
39-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
40-
logger.addHandler(handler)
33+
logger = get_logger(__name__)
4134

4235

4336
class StreamingDataset(IterableDataset):

transfer_queue/interface.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
1716
import math
1817
import os
1918
import subprocess
@@ -35,14 +34,14 @@
3534
from transfer_queue.sampler import BaseSampler
3635
from transfer_queue.storage.simple_backend import SimpleStorageUnit
3736
from transfer_queue.utils.common import get_placement_group
37+
from transfer_queue.utils.logging_utils import get_logger
3838
from transfer_queue.utils.yuanrong_utils import (
3939
cleanup_yuanrong_resources,
4040
initialize_yuanrong_backend,
4141
)
4242
from transfer_queue.utils.zmq_utils import process_zmq_server_info
4343

44-
logger = logging.getLogger(__name__)
45-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
44+
logger = get_logger(__name__)
4645

4746
_TRANSFER_QUEUE_CLIENT: Any = None
4847
_TRANSFER_QUEUE_STORAGE: Any = None

transfer_queue/metadata.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616
import copy
1717
import dataclasses
1818
import itertools
19-
import logging
20-
import os
2119
from collections import defaultdict
2220
from dataclasses import dataclass
2321
from types import MappingProxyType
@@ -27,14 +25,9 @@
2725
import torch
2826
from tensordict import TensorDict
2927

30-
logger = logging.getLogger(__name__)
31-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
28+
from transfer_queue.utils.logging_utils import get_logger
3229

33-
# Ensure logger has a handler
34-
if not logger.hasHandlers():
35-
handler = logging.StreamHandler()
36-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
37-
logger.addHandler(handler)
30+
logger = get_logger(__name__)
3831

3932

4033
# ---------------------------------------------------------------------------

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
17-
import os
1816
import pickle
1917
from typing import Any, Optional
2018

@@ -23,9 +21,9 @@
2321

2422
from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
2523
from transfer_queue.storage.clients.factory import StorageClientFactory
24+
from transfer_queue.utils.logging_utils import get_logger
2625

27-
logger = logging.getLogger(__name__)
28-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
26+
logger = get_logger(__name__)
2927

3028
MOONCAKE_STORE_IMPORTED: bool = True
3129
try:

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
import logging
17-
import os
1816
import struct
1917
from abc import ABC, abstractmethod
2018
from concurrent.futures import ThreadPoolExecutor
@@ -25,11 +23,11 @@
2523

2624
from transfer_queue.storage.clients.base import TransferQueueStorageKVClient
2725
from transfer_queue.storage.clients.factory import StorageClientFactory
26+
from transfer_queue.utils.logging_utils import get_logger
2827
from transfer_queue.utils.serial_utils import _decoder, _encoder
2928
from transfer_queue.utils.yuanrong_utils import find_reachable_host
3029

31-
logger = logging.getLogger(__name__)
32-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
30+
logger = get_logger(__name__)
3331

3432

3533
YUANRONG_DATASYSTEM_IMPORTED: bool = True

transfer_queue/storage/managers/base.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
import asyncio
1717
import itertools
18-
import logging
1918
import os
2019
import time
2120
import weakref
@@ -34,16 +33,10 @@
3433

3534
from transfer_queue.metadata import BatchMeta, extract_field_schema
3635
from transfer_queue.storage.clients.factory import StorageClientFactory
36+
from transfer_queue.utils.logging_utils import get_logger
3737
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType, ZMQServerInfo, create_zmq_socket
3838

39-
logger = logging.getLogger(__name__)
40-
logger.setLevel(os.getenv("TQ_LOGGING_LEVEL", logging.WARNING))
41-
42-
# Ensure logger has a handler
43-
if not logger.hasHandlers():
44-
handler = logging.StreamHandler()
45-
handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
46-
logger.addHandler(handler)
39+
logger = get_logger(__name__)
4740

4841
# ZMQ timeouts (in seconds) and retry configurations
4942
TQ_STORAGE_POLLER_TIMEOUT = int(os.environ.get("TQ_STORAGE_POLLER_TIMEOUT", 5))

0 commit comments

Comments
 (0)