Skip to content

Commit 1d84f12

Browse files
committed
Enhance configuration and logging features; add server experiment directory retrieval and log empty content messages
1 parent db9f245 commit 1d84f12

8 files changed

Lines changed: 57 additions & 6 deletions

File tree

ajet/default_config/ajet_default.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,3 +356,9 @@ ajet:
356356
# DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN.
357357
execute_test: False # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN.
358358
execute_testing_lambda: "" # DO NOT EDIT, FOR ROBOT TESTING PURPOSE ONLY. NOT FOR HUMAN.
359+
360+
361+
# ------------------ hydra runtime ------------------
362+
hydra:
363+
run:
364+
dir: saved_experiments/hydra_outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}

ajet/tuner_lib/experimental/oai_model_server.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
4040

4141
from ajet.utils.networking import get_host_ip
42+
from ajet.utils.message_utils import log_empty_content_messages
4243
from ajet.tuner_lib.experimental.interchange_utils import EpisodeStatus
4344
from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE, API_KEY_PREFIX
4445

@@ -288,6 +289,9 @@ async def chat_completions(request: Request, authorization: str = Header(None)):
288289
logger.warning(f"First message role is '{first_msg.get('role')}', expected 'system'. Adding default system prompt.")
289290
new_req.messages.insert(0, {"role": "system", "content": "You are a helpful assistant, your name is AgentJet."})
290291

292+
# Detect empty-content messages in the inbound request
293+
log_empty_content_messages(new_req.messages, episode_uuid=episode_uuid)
294+
291295
# Create timeline UUID
292296
timeline_uuid = uuid.uuid4().hex
293297

ajet/tuner_lib/experimental/swarm_client.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,6 +747,21 @@ def stop_engine(self):
747747
raise RuntimeError("Failed to stop training engine")
748748
self._wait_until_status_change_to(desired_status="ENGINE.OFFLINE")
749749

750+
def server_experiment_dir(self) -> str:
    """
    Fetch the experiment directory reported by the Swarm server.

    Sends a GET request to ``{server_url}/get_server_experiment_dir`` and
    reads the ``server_experiment_dir`` field of the JSON response.

    Returns:
        The directory reported by the server. When the request fails, or the
        server reports no directory (null/empty value), the relative fallback
        ``"saved_experiments"`` is returned instead, so the result can always
        be used as a path component.
    """
    # Local import keeps this method self-contained without touching the
    # module's import block.
    import logging

    fallback_dir = "saved_experiments"
    try:
        resp = self._http_client.get(
            f"{self.server_url}/get_server_experiment_dir",
            timeout=10,
        )
        raise_for_status_with_detail(resp)
        # A null/empty value from the server is replaced by the fallback:
        # callers join this value into file paths and would fail on None.
        return resp.json().get("server_experiment_dir") or fallback_dir
    except Exception as e:
        # Deliberately best-effort: never let this lookup break the caller,
        # but leave a trace of why the fallback was used.
        logging.getLogger(__name__).warning(
            "Could not fetch server experiment dir (%s: %s); falling back to %r",
            type(e).__name__,
            e,
            fallback_dir,
        )
        return fallback_dir
764+
750765
def get_rollout_stat(self) -> CurrentBatchRolloutPoolInformation:
751766
"""
752767
Get the current batch rollout pool information from the Swarm server.

ajet/tuner_lib/experimental/swarm_server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,7 @@ def override_param_callback(config):
389389
backbone=backbone,
390390
override_param_callback=override_param_callback,
391391
)
392+
shared_mem_dict["server_experiment_dir"] = exe_exp_base
392393

393394
# Setup environment variables
394395
env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp)
@@ -491,6 +492,11 @@ async def get_engine_status():
491492
"global_step": global_step,
492493
}
493494

495+
@app.get("/get_server_experiment_dir")
async def get_server_experiment_dir():
    """Report the experiment directory stored in shared memory.

    The value is read from the ``server_experiment_dir`` key of
    ``shared_mem_dict``; it is ``None`` when that key has not been set.
    """
    experiment_dir = shared_mem_dict.get("server_experiment_dir", None)
    return {"server_experiment_dir": experiment_dir}
499+
494500
# --- episode status ---
495501
@app.post("/register_episode", response_model=BoolResponse)
496502
async def register_episode(req: RegisterEpisodeRequest):

ajet/utils/core_env_vars.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def get_runtime_env(config, is_trinity: bool = False) -> dict:
4242
"TRINITY_PLUGIN_DIRS": str((Path(__file__).parent.parent / "backbone").resolve()),
4343
# "VLLM_ALLOW_RUNTIME_LORA_UPDATING": "true",
4444
"SWANLAB_API_KEY": os.getenv("SWANLAB_API_KEY", ""),
45+
"SWANLAB_LOG_DIR": os.getenv("SWANLAB_LOG_DIR", "saved_experiments/swanlog"),
4546
"AJET_CONFIG_REDIRECT": os.getenv("AJET_CONFIG_REDIRECT", ""),
4647
"AJET_DAT_INTERCHANGE_PORT": os.getenv("AJET_DAT_INTERCHANGE_PORT", data_interchange_port),
4748
"MASTER_NODE_IP": os.getenv("MASTER_NODE_IP", master_node_ip),

ajet/utils/message_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,23 @@
11
import copy
22
from typing import Dict, List
33

4+
from loguru import logger
5+
6+
7+
def log_empty_content_messages(messages: List[Dict], episode_uuid: str = "") -> None:
    """Log an error for every message that has neither content nor tool calls.

    Scans an OpenAI-compatible message list. A message is reported when its
    ``content`` is ``None`` or ``""`` and its ``tool_calls`` is missing or
    empty. This helper is purely diagnostic: it only logs and never modifies
    the messages.

    Args:
        messages: The message list to inspect; ``None`` or an empty list is
            treated as "nothing to check". Entries are normally dicts, but
            attribute-style objects are tolerated as well.
        episode_uuid: Identifier prefixed to each log line for correlation.
    """

    def _field(msg, key):
        # Prefer mapping-style access; fall back to attribute access so a
        # non-dict entry cannot make this diagnostic raise and abort the
        # request it is only supposed to observe.
        getter = getattr(msg, "get", None)
        if callable(getter):
            return getter(key)
        return getattr(msg, key, None)

    for idx, m in enumerate(messages or []):
        content = _field(m, "content")
        tool_calls = _field(m, "tool_calls") or []
        if content in (None, "") and not tool_calls:
            logger.error(
                f"[{episode_uuid}] Empty content in inbound message "
                f"index={idx} role={_field(m, 'role')} tool_call_id={_field(m, 'tool_call_id')!r} "
                f"content={content!r} tool_calls={tool_calls}"
            )
20+
421

522
# apply chat_template to a message, and then convert back to message
623
def convert_tool_to_user_message(tool_message, tokenizer, format="qwen"):

tutorial/opencode_build_aime/agent_roll_v3.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424

2525
REMOTE_MODEL_PATH = os.getenv("REMOTE_MODEL_PATH", "/mnt/data_cpfs/xielipeng.xlp/models/Qwen3-14B")
2626
BATCH_SIZE = 16
27-
PPO_EPOCH = 4
27+
PPO_EPOCH = 2
2828
NUM_REPEAT = 8
29-
MINI_BATCH_NUM = 1
29+
MINI_BATCH_NUM = 2
3030
ajet_job = AgentJetJob(
31+
ensure_new_experiment=True,
3132
algorithm="grpo",
32-
experiment_name="aime_swarm_14b_v33_ppoepoch4",
33+
experiment_name="aime_swarm_14b_v33_ppoepoch4_v3",
3334
max_env_worker=128,
3435
n_gpu=8,
3536
model=REMOTE_MODEL_PATH,
@@ -148,6 +149,8 @@ def run_eval(self, n_global_step: int):
148149
"""Run evaluation on AIME-2024 test set."""
149150
if not self.eval_tasks:
150151
return
152+
eval_log_path = os.path.join(self.swarm_worker.server_experiment_dir(), "eval_results.log")
153+
print(eval_log_path)
151154

152155
k = self.EVAL_K
153156
total_rollouts = len(self.eval_tasks) * k
@@ -182,7 +185,6 @@ def run_eval(self, n_global_step: int):
182185
f"n_tasks={len(per_task_rewards)} n_rollouts={len(flat)}"
183186
)
184187
print(summary)
185-
eval_log_path = os.path.join(os.path.dirname(__file__), "eval_results.log")
186188
with open(eval_log_path, "a") as f:
187189
f.write(summary + "\n")
188190
else:
@@ -193,7 +195,7 @@ def run_eval(self, n_global_step: int):
193195
def train(self):
194196
"""Main training loop."""
195197
# Run eval once before training starts (baseline)
196-
# self.run_eval(0)
198+
self.run_eval(0)
197199

198200
task_count = 0
199201
max_parallel = 64

tutorial/opencode_build_aime/agent_run_v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ async def run(self, messages: list[dict], sampling_params: dict) -> tuple[str, l
393393
except Exception as e:
394394
tool_response = {"text": f"Error executing tool: {e}"}
395395

396-
truncated_text = self._truncate_response(tool_response.get("text", ""))
396+
truncated_text = self._truncate_response(tool_response.get("text", "")) or "(no output)"
397397
formatted_messages.append({
398398
"role": "tool",
399399
"content": truncated_text,

0 commit comments

Comments
 (0)