Skip to content

Commit e9bc0e1

Browse files
committed
refactor: update configuration and improve swarm client functionality
1 parent 544bff1 commit e9bc0e1

File tree

8 files changed

+26
-28
lines changed

8 files changed

+26
-28
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,4 @@ modelscope_cache
170170
prompts
171171
swarmexp
172172
swarmlog
173+
werewolves_swarm

ajet/default_config/ajet_default.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ ajet:
33
project_name: "ajet_default_project"
44
experiment_name: "read_yaml_name"
55
experiment_dir: "auto" # {exp-dir}/{experiment_name}
6-
backbone: debug # `debug` or `trinity` or `verl`
6+
backbone: verl # `debug` or `trinity` or `verl`
77

88

99
model:

ajet/tuner_lib/as_oai_baseurl_apikey.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,11 @@ class OpenaiBaseUrlAndApiKey(BaseModel):
2929
episode_uuid: str = Field(default="episode_id", description="reserved field.")
3030

3131
def as_agentscope_model(self, *args, **kwargs):
32-
from agentscope.model import DashScopeChatModel
33-
return DashScopeChatModel(model_name="AgentJet-Model", api_key=self.api_key, base_http_api_url=self.base_url)
32+
from agentscope.model import OpenAIChatModel
33+
return OpenAIChatModel(
34+
model_name="AgentJet-Model", api_key=self.api_key,
35+
client_args={"base_url": self.base_url}
36+
)
3437

3538
def as_raw_openai_sdk_client(self, *args, **kwargs):
3639
from openai import AsyncOpenAI

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _should_throttle(self, throttle_policy: SwarmThrottlePolicy, pool_info: Curr
195195
self._remember_seen_task(throttle_policy.current_task_id, throttle_policy.expected_batch_size, throttle_policy.expected_num_repeat)
196196
return should_throttle
197197

198-
def begin_episode(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
198+
def begin_episode(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
199199
"""
200200
Block until an episode is claimed.
201201
Argument:
@@ -210,7 +210,7 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt
210210
"""
211211
return self._begin_episode_auto_retry(discard_episode_timeout, episode_type, throttle_policy)
212212

213-
def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
213+
def _begin_episode_auto_retry(self, discard_episode_timeout=240, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
214214
# max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by the client (calls to `end_episode` will be re-routed to `abort_episode`)
215215
max_episode_time = 2*discard_episode_timeout
216216

ajet/tuner_lib/experimental/as_swarm_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ async def get_episode_buffer():
708708
@app.post("/update_current_batch_rollout_pool_information", response_model=BoolResponse)
709709
async def update_current_batch_rollout_pool_information(req: CurrentBatchRolloutPoolInformation):
710710
"""Update the current batch rollout pool information."""
711-
if VERBOSE:
711+
if DEBUG:
712712
logger.info(f"Running /update_current_batch_rollout_pool_information")
713713
try:
714714
with shared_mem_dict_lock:

tutorial/example_math_swarm/math.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,13 @@ def main():
5252
)
5353

5454
def rollout(task):
55-
try:
56-
# begin episode
57-
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60)
58-
# execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
59-
workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
60-
# report output back to swarm remote
61-
swarm_worker.end_episode(task, episode_uuid, workflow_output)
62-
return
63-
except:
64-
pass
55+
# begin episode
56+
episode_uuid, api_baseurl_key = swarm_worker.begin_episode(discard_episode_timeout=60)
57+
# execute agent ( base_url = api_baseurl_key.base_url, api_key = api_baseurl_key.api_key )
58+
workflow_output = execute_agent(task, api_baseurl_key) # reward is in `workflow_output`
59+
# report output back to swarm remote
60+
swarm_worker.end_episode(task, episode_uuid, workflow_output)
61+
return
6562

6663
executor = PeriodicDrainThreadPoolExecutor(workers=GRPO_N * REMOTE_BATCH_SIZE, auto_retry=True)
6764
for _ in range(NUM_EPOCH):

tutorial/example_werewolves/start.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from agentscope.agent import ReActAgent
1414
from agentscope.formatter import DashScopeMultiAgentFormatter, OpenAIMultiAgentFormatter
15-
from agentscope.model import OpenAIChatModel
15+
from agentscope.model import OpenAIChatModel, DashScopeChatModel
1616
from loguru import logger
1717
from pydantic import Field
1818

@@ -81,8 +81,8 @@ def get_official_agent_prompt(name) -> str:
8181

8282
class ExampleWerewolves(Workflow):
8383
trainable_targets: List[str] | None = Field(default=["werewolf"], description="List of agents to be fine-tuned.")
84-
big_external_opponent_llm_url = "http://22.17.52.4:2888/v1"
85-
big_external_opponent_llm_name = "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/"
84+
big_external_opponent_llm_url: str = Field(default="http://22.17.52.4:2888/v1", description="The URL of the big external opponent LLM. You can replace it with any OpenAI-compatible LLM API URL.")
85+
big_external_opponent_llm_name: str = Field(default="/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen3-235B-A22B-Instruct-2507/", description="The model name of the big external opponent LLM. You can replace it with any OpenAI-compatible LLM name.")
8686

8787
async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
8888

@@ -121,9 +121,7 @@ async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> Workfl
121121
name=f"Player{i + 1}",
122122
sys_prompt=get_official_agent_prompt(f"Player{i + 1}"),
123123
model=model_for_this_agent,
124-
formatter=DashScopeMultiAgentFormatter()
125-
if role in self.trainable_targets
126-
else OpenAIMultiAgentFormatter(),
124+
formatter=DashScopeMultiAgentFormatter() if isinstance(model_for_this_agent, DashScopeChatModel) else OpenAIMultiAgentFormatter(),
127125
max_iters=3 if role in self.trainable_targets else 5,
128126
)
129127
# agent.set_console_output_enabled(False)

tutorial/example_werewolves_swarm/agent_roll.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,12 @@
11
# -*- coding: utf-8 -*-
22

33
import os
4-
import re
5-
import requests
6-
from textwrap import dedent
7-
from ajet.schema.task import Task, WorkflowOutput
4+
from ajet.schema.task import Task
85
from ajet.copilot.job import AgentJetJob
96
from ajet.task_reader import RouterTaskReader
107
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
118
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
12-
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
9+
from ajet.default_config.ajet_default import AjetTaskReader
1310
from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient
1411

1512
NUM_EPOCH = 10000
@@ -33,7 +30,7 @@ def main():
3330
swarm_worker = SwarmClient(AJET_SWARM_URL)
3431
swarm_worker.auto_sync_train_config_and_start_engine(
3532
ajet_job,
36-
force_restart=True,
33+
force_restart=False,
3734
)
3835

3936
GRPO_N = ajet_job.num_repeat
@@ -65,6 +62,8 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
6562
from tutorial.example_werewolves.start import ExampleWerewolves
6663
game = ExampleWerewolves(
6764
trainable_targets=["werewolf"],
65+
big_external_opponent_llm_name="Qwen/Qwen3-235B-A22B-Instruct-2507",
66+
big_external_opponent_llm_url="http://22.16.90.187/v1",
6867
)
6968
res = asyncio.run(game.execute(task, api_baseurl_key))
7069
return res

0 commit comments

Comments
 (0)