Skip to content

Commit b3f6dfa

Browse files
committed
impl multi rollout stop condition
1 parent 98db2b7 commit b3f6dfa

File tree

20 files changed

+315
-427
lines changed

20 files changed

+315
-427
lines changed

ajet/backbone/trainer_verl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -610,10 +610,10 @@ def fit(self): # noqa: C901
610610
traj.reward_structure.madness for traj in context_tracker_arr
611611
]
612612
# reward = [traj.reward_structure.raw_reward for traj in context_tracker_arr]
613-
round_cnt = [traj.round_cnt for traj in context_tracker_arr]
613+
llm_call_cnt = [traj.llm_call_cnt for traj in context_tracker_arr]
614614
metrics.update(
615615
{
616-
"critic/round_cnt": np.mean(round_cnt),
616+
"critic/llm_call_cnt": np.mean(llm_call_cnt),
617617
"critic/madness_rate": np.mean(madness_rate),
618618
"critic/success_rate": np.mean(success_rate),
619619
"critic/real_success_rate": np.mean(

ajet/context_tracker/base_tracker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs):
144144

145145
# meta data attributes
146146
self.tag = ""
147-
self.round_cnt = 0
147+
self.llm_call_cnt = 0
148148
self.generated_token_cnt = 0
149149
self.current_batch_success_rate: float = float("-inf")
150150
self.current_batch_reward: float = float("-inf")
@@ -171,7 +171,7 @@ def reset(self):
171171
self.current_batch_success_rate: float = float("-inf")
172172
self.current_batch_reward: float = float("-inf")
173173
self.already_mad_flag: bool = False
174-
self.round_cnt = 0
174+
self.llm_call_cnt = 0
175175
self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None
176176

177177
def group_tokenize(self):

ajet/context_tracker/multiagent_tracking.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ def step_prepare(self, messages: List[dict], tools: List = [], timeline_uuid: st
201201
custom_sampling_params = {}
202202
if not context_safe:
203203
self.context_overflow = True
204+
logger.warning(f"[{self.workflow_task.episode_uuid}] Stop tracking timelines because {info}.")
205+
204206

205207
self.timeline_cache[timeline_uuid] = timeline
206208
return context_safe, token_overflow, info, converted_message, custom_sampling_params, tools
@@ -218,7 +220,7 @@ def step_track(
218220
assert timeline_uuid in self.timeline_cache, "Timeline UUID not found in cache. Please ensure `step_prepare` is called before `step_track`."
219221

220222
# round ++
221-
self.round_cnt += 1
223+
self.llm_call_cnt += 1
222224

223225
# get timeline from cache
224226
timeline = self.timeline_cache.pop(timeline_uuid, [])

ajet/default_config/ajet_default.yaml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,17 @@ ajet:
293293
max_fastapi_threads: 512 # 64 or 128 is fine
294294
max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
295295
already_started: False # do not edit, used by `swarm`
296-
296+
# what is the stop condition for swarm mode sample collection
297+
# "rollout_until_finish_enough_episodes":
298+
# AgentJet simply counts the number of completed episodes, and stop when it has collected [>= (ajet.data.train_batch_size * ajet.rollout.num_repeat)] samples
299+
# "rollout_until_finish_enough_tasks":
300+
# AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED **task_id**.
301+
# (Hint: a **task_id** is considered "FINISHED" when [>= ajet.rollout.num_repeat] episodes of this **task_id** have been completed.)
302+
# "rollout_until_finish_enough_non_dummy_tasks":
303+
# AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED & NON-DUMMY **task_id**.
304+
# (Hint: a **task_id** is considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.)
305+
swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks"
306+
swarm_mode_sample_collection_max_cached_episodes: 9999
297307

298308
task_runner:
299309
# submit llm infer submit method
@@ -303,7 +313,6 @@ ajet:
303313
wrapper_type: "asyncio-with-gc"
304314
# - wrapper_type: "asyncio-with-gc": safe, with periodic garbage collection to prevent event loop leaks (recommended)
305315
# - wrapper_type: "asyncio": fast, but may cause event loop leak in long run
306-
# - wrapper_type: "multi-processing": safe, but resource consuming
307316

308317
# when `wrapper_type` is `multi-processing`, the timeout for each task
309318
wrapper_multiprocessing_timeout: 3600 # in seconds

ajet/default_config/ajet_ts_default.yaml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ ajet:
3333
max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
3434
already_started: False # do not edit, used by `swarm`
3535

36+
# the method to determine when to stop rollout in swarm mode. Options:
37+
# "rollout_until_finish_enough_episodes":
38+
# AgentJet simply counts the number of completed episodes, and stop when it has collected [>= (ajet.data.train_batch_size * ajet.rollout.num_repeat)] samples
39+
# "rollout_until_finish_enough_tasks":
40+
# AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED **task_id**.
41+
# (Hint: a **task_id** is considered "FINISHED" when [>= ajet.rollout.num_repeat] episodes of this **task_id** have been completed.)
42+
# "rollout_until_finish_enough_non_dummy_tasks":
43+
# AgentJet will identify the **task_id** of each episode, and stop when it has collected [>= ajet.data.train_batch_size] unique & FINISHED & NON-DUMMY **task_id**.
44+
# (Hint: a **task_id** is considered "NON-DUMMY" at least one of **episodes** of **task_id** has **different** reward value.)
45+
swarm_mode_sample_collection_method: "rollout_until_finish_enough_tasks"
46+
3647
rollout:
3748
# maximum number of parallel environments / simulate workers
3849
max_env_worker: 128

ajet/default_config/verl/verl_default.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ ajet:
66
rollout:
77
step_skip_action: 0
88
submit_oversample_multiplier: 1.5
9-
enable_oversample: False
109

1110
actor_rollout_ref:
1211
actor:

ajet/task_rollout/async_llm_bridge.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,23 @@
55
import uuid
66
from typing import Any, Callable, Dict, List, Literal, Union
77

8-
9-
108
from loguru import logger
119
from omegaconf import DictConfig
1210
from pydantic import BaseModel
1311
from vllm.entrypoints.openai.tool_parsers.hermes_tool_parser import Hermes2ProToolParser
1412
from vllm.outputs import RequestOutput as VerlVllmRequestOutput
15-
1613
from agentscope.model import ChatResponse as AgentScopeChatResponse
1714
from openai.types.chat.chat_completion import ChatCompletion as OpenAIChatCompletion
1815

19-
ChatResponse = Union[OpenAIChatCompletion, AgentScopeChatResponse]
20-
21-
from ajet.context_tracker.multiagent_tracking import (
22-
MultiAgentContextTracker,
23-
)
24-
from ajet.schema.convertion import convert_llm_proxy_response_to_oai_response
25-
from ajet.schema.convertion import convert_llm_proxy_response_to_agentscope_response
2616
from ajet.schema.logprob import TokenAndProb
17+
from ajet.utils.tokenizer import ajet_apply_chat_template
2718
from ajet.utils.async_utils import run_async_coroutine_with_timeout
2819
from ajet.utils.testing_utils import _mock_if_test_mode, _test_if_test_mode
29-
from ajet.utils.tokenizer import ajet_apply_chat_template
20+
from ajet.schema.convertion import convert_llm_proxy_response_to_oai_response
21+
from ajet.schema.convertion import convert_llm_proxy_response_to_agentscope_response
22+
from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker
23+
24+
ChatResponse = Union[OpenAIChatCompletion, AgentScopeChatResponse]
3025

3126

3227
class AjetStandardLlmBridgeRequest(BaseModel):

0 commit comments

Comments
 (0)