Skip to content

Commit b150420

Browse files
committed
patch save dir bug
1 parent 621d235 commit b150420

6 files changed

Lines changed: 26 additions & 24 deletions

File tree

ajet/context_tracker/multiagent_tracking.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -319,6 +319,7 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline):
319319
)
320320
):
321321
logger.bind(exception=True).info(f"General Warning: merge failure discovered.\n")
322+
# from ajet import bp; bp("SWARM")
322323
return
323324

324325

ajet/launcher.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def main():
214214
"Please provide a valid config file for swarm server mode."
215215
)
216216
if args.conf:
217-
exp_dir = args.exp_dir or DEFAULT_DIR
217+
exp_base_dir = args.exp_dir or DEFAULT_DIR
218218
yaml_path = args.conf
219219
(
220220
main_yaml_fp,
@@ -223,7 +223,7 @@ def main():
223223
exp_config,
224224
) = prepare_experiment_config(
225225
yaml_path=yaml_path,
226-
exp_base_dir=exp_dir,
226+
exp_base_dir=exp_base_dir,
227227
backbone=args.backbone,
228228
storage=(not args.swarm_server)
229229
)

ajet/swarm_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def start_swarm_server(env, config, port):
4242
def cmd_start(args):
4343
"""Handle the 'start' subcommand."""
4444
# Use default config if not provided
45-
exp_dir = args.exp_dir or DEFAULT_DIR
45+
exp_base_dir = args.exp_dir or DEFAULT_DIR
4646
if not args.conf:
4747
args.conf = os.path.abspath(
4848
os.path.join(
@@ -62,7 +62,7 @@ def cmd_start(args):
6262
exp_config,
6363
) = prepare_experiment_config(
6464
yaml_path=yaml_path,
65-
exp_base_dir=exp_dir,
65+
exp_base_dir=exp_base_dir,
6666
backbone="verl",
6767
storage=False
6868
)

ajet/task_rollout/native_parallel_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def rollout_swarm( # noqa: C901
173173
Build a pool of threads to run context trackers in parallel,
174174
each thread re-spawn after complete, until reaching conditions to stop.
175175
"""
176-
# from ajet import bp; bp("SWARM")
176+
177177
tracker_array: List[SingleAgentContextTracker] = []
178178
rollout_n = self.rollout_n
179179
n_batch_task = len(tasks)

ajet/tuner_lib/experimental/as_swarm_server.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,9 +334,11 @@ async def start_engine():
334334
yaml_str = shared_mem_dict["train_config_yaml"]
335335
config_dict = yaml_module.safe_load(yaml_str)
336336
backbone = config_dict.get("ajet", {}).get("backbone", "verl")
337-
exp_base_dir = os.path.dirname(
338-
config_dict.get("ajet", {}).get("experiment_dir", "saved_experiments")
339-
)
337+
DEFAULT_DIR = "saved_experiments"
338+
experiment_dir = config_dict.get("ajet", {}).get("experiment_dir", DEFAULT_DIR)
339+
if experiment_dir == "auto":
340+
exp_base_dir = DEFAULT_DIR
341+
exp_base_dir = os.path.dirname(os.path.abspath(experiment_dir))
340342

341343
# Save YAML to temporary file
342344
with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".yaml") as temp_file:

tutorial/example_werewolves_swarm/agent_roll.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33
import os
4-
from ajet.schema.task import Task
4+
from ajet.schema.task import Task, WorkflowTask
55
from ajet.copilot.job import AgentJetJob
66
from ajet.task_reader import RouterTaskReader
77
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
@@ -24,37 +24,36 @@ def main():
2424
base_yaml_config="tutorial/example_werewolves_swarm/werewolves.yaml",
2525
algorithm="grpo",
2626
experiment_name="werewolves_swarm",
27+
max_env_worker=128,
2728
)
2829

2930
# Hand shake with remote swarm server
3031
swarm_worker = SwarmClient(AJET_SWARM_URL)
3132
swarm_worker.auto_sync_train_config_and_start_engine(
3233
ajet_job,
33-
force_restart=False,
34+
# force_restart=True,
3435
)
3536

3637
GRPO_N = ajet_job.num_repeat
3738
REMOTE_BATCH_SIZE = ajet_job.batch_size
3839

3940
def rollout(task):
40-
try:
41-
# begin episode
42-
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60)
43-
# execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
44-
workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
45-
# report output back to swarm remote
46-
swarm_worker.end_episode(task, episode_uuid, workflow_output)
47-
return
48-
except:
49-
pass
41+
# begin episode
42+
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=240)
43+
# execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
44+
workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
45+
# report output back to swarm remote
46+
swarm_worker.end_episode(task, episode_uuid, workflow_output)
47+
return
5048

51-
executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True)
49+
50+
executor = PeriodicDrainThreadPoolExecutor(workers=1, max_parallel=64, auto_retry=True, block_first_run=True)
5251
for _ in range(NUM_EPOCH):
5352
for _, task in enumerate(dataset.generate_training_tasks()):
5453
for _ in range(GRPO_N):
5554
executor.submit_with_periodic_drain(fn=rollout, task=task)
5655

57-
return None
56+
return
5857

5958

6059
def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
@@ -63,9 +62,9 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
6362
game = ExampleWerewolves(
6463
trainable_targets=["werewolf"],
6564
big_external_opponent_llm_name="Qwen/Qwen3-235B-A22B-Instruct-2507",
66-
big_external_opponent_llm_url="http://22.14.116.243/v1",
65+
big_external_opponent_llm_url="http://22.14.116.243:2888/v1",
6766
)
68-
res = asyncio.run(game.execute(task, api_baseurl_key))
67+
res = asyncio.run(game.execute(WorkflowTask(task=task), api_baseurl_key))
6968
return res
7069

7170

0 commit comments

Comments
 (0)