Skip to content

Commit 58959a6

Browse files
committed
Add training model path to CurrentBatchRolloutPoolInformation and update logging in sync_train_config
1 parent 6945c38 commit 58959a6

3 files changed

Lines changed: 16 additions & 1 deletion

File tree

ajet/tuner_lib/experimental/swarm_overwatch_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ class CurrentBatchRolloutPoolInformation(BaseModel):
1818
engine_status: str | None = None
1919
global_step: int | None = None
2020
booting_start_time: float | None = None # timestamp when ENGINE.BOOTING started
21+
training_model_path: str | None = None # model path from synced training config

ajet/tuner_lib/experimental/swarm_server.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,16 +292,25 @@ async def sync_train_config(req: SyncTrainConfigRequest):
292292
)
293293

294294
try:
295+
import yaml as yaml_module
295296
yaml_str = req.yaml_as_string
296297
logger.info("[sync_train_config] Received training configuration")
297298
if DEBUG:
298299
logger.debug(f"[sync_train_config] YAML content:\n{yaml_str}...")
299300

301+
# Extract model path from YAML config
302+
try:
303+
config_dict = yaml_module.safe_load(yaml_str)
304+
model_path = config_dict.get("ajet", {}).get("model", {}).get("path", None)
305+
except Exception:
306+
model_path = None
307+
300308
# Store the YAML config in shared memory for start_engine to use
301309
with shared_mem_dict_lock:
302310
shared_mem_dict["train_config_yaml"] = yaml_str
311+
shared_mem_dict["training_model_path"] = model_path
303312

304-
logger.info("[sync_train_config] Successfully stored training configuration")
313+
logger.info(f"[sync_train_config] Successfully stored training configuration (model: {model_path})")
305314
return {"success": True}
306315
except Exception as e:
307316
logger.error(f"[sync_train_config] Error: {e}")
@@ -749,6 +758,7 @@ async def get_current_batch_rollout_pool_information():
749758
pool_info.engine_status = shared_mem_dict.get("engine_status", None)
750759
pool_info.global_step = shared_mem_dict.get("global_step", None)
751760
pool_info.booting_start_time = shared_mem_dict.get("booting_start_time", None)
761+
pool_info.training_model_path = shared_mem_dict.get("training_model_path", None)
752762

753763
# Build running_episode_details for claimed episodes
754764
running_episode_details = {}

ajet/utils/swarm_overwatch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def create_header(
8181

8282
# Add engine status and global step if available
8383
if info:
84+
if info.training_model_path:
85+
header_text.append(
86+
f"\nTraining Model: {info.training_model_path}", style="bold white"
87+
)
8488
if info.engine_status:
8589
header_text.append(
8690
f"\nEngine Status: {info.engine_status}", style="bold yellow"

0 commit comments

Comments
 (0)