Skip to content

Commit 7d5e623

Browse files
committed
change print to logger.log
1 parent aa511a7 commit 7d5e623

File tree

14 files changed

+86
-175
lines changed

14 files changed

+86
-175
lines changed

ajet/backbone/main_verl.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import hydra
2323
import ray
2424
from beast_logger import print_dict
25+
from loguru import logger
2526
from omegaconf import OmegaConf
2627
from verl.trainer.ppo.reward import load_reward_manager
2728
from verl.utils.device import is_cuda_available
@@ -112,7 +113,7 @@ def run(self, config):
112113
from omegaconf import OmegaConf
113114
from verl.utils.fs import copy_to_local
114115

115-
print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
116+
logger.info(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
116117
pprint(OmegaConf.to_container(config, resolve=True))
117118
OmegaConf.resolve(config)
118119

@@ -148,8 +149,6 @@ def run(self, config):
148149
from verl.workers.fsdp_workers import CriticWorker
149150
elif use_legacy_worker_impl == "disable":
150151
from verl.workers.roles import CriticWorker
151-
152-
print("Using new worker implementation")
153152
else:
154153
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
155154

ajet/backbone/main_vllm.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ajet.utils.launch_utils import set_loguru_default_color
1111
from ajet.schema.logprob import TokenAndProb
1212
from ajet.utils.core_env_vars import get_runtime_env
13+
from loguru import logger
1314

1415
set_loguru_default_color()
1516

@@ -116,12 +117,11 @@ def run(config):
116117
config.ajet.task_reader,
117118
)
118119
tasks = task_reader.get_validation_tasks()
119-
print(tasks[:2])
120+
logger.info(tasks[:n_task])
120121
ctx_tracker = parallel_env.rollout(
121122
tasks=tasks[:n_task], mode="sample", epoch="1"
122123
) # "sample" or "validate"
123124
_ = parallel_env.to_dataproto(ctx_tracker)
124-
print("Generated batch output")
125125

126126

127127
@hydra.main(
@@ -133,7 +133,6 @@ def main(config):
133133
from omegaconf import OmegaConf
134134

135135
OmegaConf.resolve(config)
136-
print("*" * 20)
137136

138137
runtime_env = get_runtime_env()
139138
os.environ.update(runtime_env["env_vars"])
@@ -147,12 +146,12 @@ def companion_launch():
147146

148147
from ajet.utils.smart_daemon import LaunchCommandWhenAbsent
149148

150-
print("Launching companion process for async LLM server...")
149+
logger.info("Launching companion process for async LLM server...")
151150
model_path = config.ajet.model.path
152151
tensor_parallel_size = config.ajet.debug.debug_tensor_parallel_size
153152
n_avail_gpus = torch.cuda.device_count()
154153
if tensor_parallel_size > n_avail_gpus:
155-
print(
154+
logger.info(
156155
f"Warning: tensor_parallel_size {tensor_parallel_size} is greater than available GPUs {n_avail_gpus}. Setting tensor_parallel_size to {n_avail_gpus}."
157156
)
158157
tensor_parallel_size = n_avail_gpus

ajet/backbone/trainer_verl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,15 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
293293
)
294294

295295
if self.config.algorithm.use_kl_in_reward and config.actor_rollout_ref.actor.use_kl_loss:
296-
print("NOTICE: You have both enabled in-reward kl and kl loss.")
296+
logger.warning("NOTICE: You have both enabled in-reward kl and kl loss.")
297297

298298
# critic
299299
if self.use_critic:
300300
critic_config = omega_conf_to_dataclass(config.critic)
301301
critic_config.validate(n_gpus, config.ajet.data.train_batch_size)
302302

303303
if config.data.get("val_batch_size", None) is not None:
304-
print(
304+
logger.warning(
305305
"WARNING: val_batch_size is deprecated."
306306
+ " Validation datasets are sent to inference engines as a whole batch,"
307307
+ " which will schedule the memory themselves."
@@ -313,7 +313,7 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
313313
config.ajet.rollout.temperature > 0
314314
), "validation gen temperature should be greater than 0 when enabling do_sample"
315315

316-
print("[validate_config] All configuration checks passed successfully!")
316+
logger.success("[validate_config] All configuration checks passed successfully!")
317317

318318
def init_workers(self):
319319
"""Initialize distributed training workers using Ray backend.
@@ -787,7 +787,7 @@ def fit(self): # noqa: C901
787787
or esi_close_to_expiration
788788
):
789789
if esi_close_to_expiration:
790-
print("Force saving checkpoint: ESI instance expiration approaching.")
790+
logger.info("Force saving checkpoint: ESI instance expiration approaching.")
791791
with marked_timer("save_checkpoint", timing_raw, color="green"):
792792
self._save_checkpoint()
793793

ajet/context_tracker/basic_tracker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import torch
12
import copy
23
from collections import defaultdict
34
from typing import List, Tuple
4-
5-
import torch
5+
from loguru import logger
66

77
from ajet.context_tracker.base_tracker import (
88
BaseTracker,
@@ -231,7 +231,7 @@ def group_tokenize_multi_group(self):
231231
sample_arr += [sample]
232232

233233
if len(sample_arr) > max_num_group:
234-
print(f"Warning: allow {max_num_group} groups, but got {len(sample_arr)} groups")
234+
logger.warning(f"Warning: allow {max_num_group} groups, but got {len(sample_arr)} groups")
235235
import random
236236

237237
sample_arr = random.sample(sample_arr, max_num_group) # preserve max_num_group groups

ajet/default_config/ajet_default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ ajet:
77

88

99
# the experimental reverse proxy feature that allows `tuner.as_oai_baseurl_apikey` feature
10-
enable_experimental_reverse_proxy: True
10+
enable_experimental_reverse_proxy: False
1111

1212
model:
1313
# which model should be trained

ajet/schema/convertion.py

Lines changed: 3 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def convert_llm_proxy_response_to_oai_response(llm_proxy_response):
4646
usage=usage,
4747
)
4848

49-
# copied from AgentScope's DashScopeChatModule
49+
50+
51+
# modified from AgentScope's DashScopeChatModule
5052
def convert_llm_proxy_response_to_agentscope_response(
5153
message,
5254
structured_model: Type[BaseModel] | None = None,
@@ -105,91 +107,3 @@ def convert_llm_proxy_response_to_agentscope_response(
105107

106108
return parsed_response
107109

108-
109-
110-
def test_convert_llm_proxy_response_to_oai_response():
111-
"""Test the conversion from llm_proxy_response to OpenAI ChatCompletion format."""
112-
113-
from ajet.schema.logprob import TokenAndProb
114-
# Test case 1: Basic response with content only
115-
llm_proxy_response_basic = {
116-
"role": "assistant",
117-
"request_id": "req-123456",
118-
"content": "Hello, how can I help you today?",
119-
"tool_calls": None,
120-
"tokens": [
121-
TokenAndProb(
122-
token_id=123,
123-
logprob=-0.5,
124-
decoded_string="Hello",
125-
),
126-
TokenAndProb(
127-
token_id=456,
128-
logprob=-0.3,
129-
decoded_string=",",
130-
),
131-
],
132-
}
133-
134-
result = convert_llm_proxy_response_to_oai_response(llm_proxy_response_basic)
135-
136-
assert result.id == "req-123456"
137-
assert result.object == "chat.completion"
138-
assert len(result.choices) == 1
139-
assert result.choices[0].message.role == "assistant"
140-
assert result.choices[0].message.content == "Hello, how can I help you today?"
141-
assert result.choices[0].message.tool_calls is None
142-
assert result.choices[0].finish_reason == "stop"
143-
assert result.usage is not None
144-
assert result.usage.completion_tokens == 2
145-
assert result.usage.total_tokens == 2
146-
147-
print("✓ Test case 1 passed: Basic response with content")
148-
149-
# Test case 2: Response with tool calls
150-
llm_proxy_response_with_tools = {
151-
"role": "assistant",
152-
"request_id": "req-789012",
153-
"content": "",
154-
"tool_calls": [
155-
{
156-
"id": "call_abc123",
157-
"type": "function",
158-
"function": {
159-
"name": "get_weather",
160-
"arguments": '{"location": "San Francisco"}'
161-
}
162-
}
163-
],
164-
"tokens": [],
165-
}
166-
167-
result2 = convert_llm_proxy_response_to_oai_response(llm_proxy_response_with_tools)
168-
169-
assert result2.id == "req-789012"
170-
assert result2.choices[0].message.content == ""
171-
assert result2.choices[0].message.tool_calls is not None
172-
assert len(result2.choices[0].message.tool_calls) == 1
173-
assert result2.usage is None # No tokens provided
174-
175-
print("✓ Test case 2 passed: Response with tool calls")
176-
177-
# Test case 3: Minimal response with defaults
178-
llm_proxy_response_minimal = {
179-
"content": "Test response"
180-
}
181-
182-
result3 = convert_llm_proxy_response_to_oai_response(llm_proxy_response_minimal)
183-
184-
assert result3.id == "chatcmpl-default"
185-
assert result3.choices[0].message.role == "assistant"
186-
assert result3.choices[0].message.content == "Test response"
187-
assert result3.model == "unknown"
188-
189-
print("✓ Test case 3 passed: Minimal response with defaults")
190-
191-
print("\n✅ All tests passed!")
192-
193-
194-
if __name__ == "__main__":
195-
test_convert_llm_proxy_response_to_oai_response()

ajet/schema/trajectory.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from pydantic import BaseModel, Field
5+
from loguru import logger
56

67

78
class Reward(BaseModel):
@@ -31,7 +32,6 @@ def performance_reward(self):
3132
# this reward is NOT used in training
3233
if (self.step_reward_arr is not None) and len(self.step_reward_arr) > 0:
3334
res = np.mean(self.step_reward_arr)
34-
# print(f"Performance reward computed as mean of step_reward_arr: {res}")
3535
return res
3636
else:
3737
return self.raw_reward
@@ -146,13 +146,13 @@ def truncate_output_ids(self) -> None:
146146

147147
if len(self.response_ids) > self.max_response_len:
148148
truncate_any = True
149-
print(
149+
logger.warning(
150150
"-------------------------------------------------------------------------------------------------------"
151151
)
152-
print(
152+
logger.warning(
153153
f"Warning: response_ids length {len(self.response_ids)} exceeds max_response_len {self.max_response_len}, truncating."
154154
)
155-
print(
155+
logger.warning(
156156
"-------------------------------------------------------------------------------------------------------"
157157
)
158158
self.response_ids = self.response_ids[: self.max_response_len]

ajet/task_reader/tracing_reader/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def __init__(
3131

3232
super().__init__(reader_config)
3333
# config patch
34-
# print("*********", config, "**********")
3534
self.reader_config = reader_config.feedback_tracing
3635

3736
logger.info(

ajet/task_rollout/dashscope_llm_bridge.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def external_llm_chat_fn(messages, sampling_params_override={}, request_id=""):
7171
return {"role": message["role"], "content": message["content"]}
7272
except Exception as e:
7373
logger.bind(exception=True).exception(f"Error calling alien llm: {e}")
74+
logger.warning(f"Error calling alien llm: {e}, retrying...")
7475
time.sleep(5)
75-
print(f"Error calling alien llm: {e}, retrying...")
7676
raise RuntimeError(f"Failed to get response from alien llm after {max_try} attempts")
7777

7878
return external_llm_chat_fn

ajet/tuner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def as_oai_baseurl_apikey(
101101
```
102102
"""
103103

104+
assert self.config.ajet.enable_experimental_reverse_proxy, "Please enable `ajet.enable_experimental_reverse_proxy` in yaml config to use `as_oai_baseurl_apikey` feature."
104105
baseurl_apikey_model = OpenaiClientBaseUrlTuner(
105106
config=self.config,
106107
context_tracker=self.context_tracker,

0 commit comments

Comments (0)