11# -*- coding: utf-8 -*-
22
33import os
4- from ajet .schema .task import Task
4+ from ajet .schema .task import Task , WorkflowTask
55from ajet .copilot .job import AgentJetJob
66from ajet .task_reader import RouterTaskReader
77from ajet .utils .thread_executors import PeriodicDrainThreadPoolExecutor
@@ -24,37 +24,36 @@ def main():
2424 base_yaml_config = "tutorial/example_werewolves_swarm/werewolves.yaml" ,
2525 algorithm = "grpo" ,
2626 experiment_name = "werewolves_swarm" ,
27+ max_env_worker = 128 ,
2728 )
2829
2930 # Hand shake with remote swarm server
3031 swarm_worker = SwarmClient (AJET_SWARM_URL )
3132 swarm_worker .auto_sync_train_config_and_start_engine (
3233 ajet_job ,
33- force_restart = False ,
34+ # force_restart=True ,
3435 )
3536
3637 GRPO_N = ajet_job .num_repeat
3738 REMOTE_BATCH_SIZE = ajet_job .batch_size
3839
3940 def rollout (task ):
40- try :
41- # begin episode
42- episode_uuid , api_baseurl_key = swarm_worker .begin_episode (discard_episode_timeout = 60 )
43- # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
44- workflow_output = execute_agent (task , api_baseurl_key ) # reward is in `workflow_output`
45- # report output back to swarm remote
46- swarm_worker .end_episode (task , episode_uuid , workflow_output )
47- return
48- except :
49- pass
41+ # begin episode
42+ episode_uuid , api_baseurl_key = swarm_worker .begin_episode (discard_episode_timeout = 240 )
43+ # execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
44+ workflow_output = execute_agent (task , api_baseurl_key ) # reward is in `workflow_output`
45+ # report output back to swarm remote
46+ swarm_worker .end_episode (task , episode_uuid , workflow_output )
47+ return
5048
51- executor = PeriodicDrainThreadPoolExecutor (workers = GRPO_N * REMOTE_BATCH_SIZE , auto_retry = True )
49+
50+ executor = PeriodicDrainThreadPoolExecutor (workers = 1 , max_parallel = 64 , auto_retry = True , block_first_run = True )
5251 for _ in range (NUM_EPOCH ):
5352 for _ , task in enumerate (dataset .generate_training_tasks ()):
5453 for _ in range (GRPO_N ):
5554 executor .submit_with_periodic_drain (fn = rollout , task = task )
5655
57- return None
56+ return
5857
5958
6059def execute_agent (task : Task , api_baseurl_key : OpenaiBaseUrlAndApiKey ):
@@ -63,9 +62,9 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
6362 game = ExampleWerewolves (
6463 trainable_targets = ["werewolf" ],
6564 big_external_opponent_llm_name = "Qwen/Qwen3-235B-A22B-Instruct-2507" ,
66- big_external_opponent_llm_url = "http://22.14.116.243/v1" ,
65+ big_external_opponent_llm_url = "http://22.14.116.243:2888 /v1" ,
6766 )
68- res = asyncio .run (game .execute (task , api_baseurl_key ))
67+ res = asyncio .run (game .execute (WorkflowTask ( task = task ) , api_baseurl_key ))
6968 return res
7069
7170
0 commit comments