Skip to content

Commit 5132c2b

Browse files
committed
improve readability
1 parent 4cb513b commit 5132c2b

File tree

7 files changed

+71
-27
lines changed

7 files changed

+71
-27
lines changed

ajet/backbone/trainer_verl.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -563,8 +563,7 @@ def fit(self): # noqa: C901
563563
# pass global_steps to trace
564564
gen_batch.meta_info["global_steps"] = self.global_steps
565565
is_last_step = self.global_steps >= self.total_training_steps
566-
from ajet import bp
567-
bp("BATCH")
566+
568567
with marked_timer("step", timing_raw):
569568
# generate a batch
570569
logger.info("rollout step begin")
@@ -597,6 +596,8 @@ def fit(self): # noqa: C901
597596
context_tracker_arr: List[BaseContextTracker] = self.parallel_env.rollout(
598597
tasks, mode="sample", epoch=f"train.{epoch}"
599598
)
599+
from ajet import bp
600+
bp("BATCH")
600601
logger.info("end fit rollout")
601602
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
602603
logger.info("end dataproto convertion")

ajet/context_tracker/base_tracker.py

Lines changed: 27 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -113,34 +113,45 @@ def replace_token_ids(
113113
class BaseTracker(object):
114114
def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs):
115115

116+
# disable read only mode
117+
self._read_only = False
118+
self._discarded = False
119+
120+
# task related info
116121
self.workflow_task = workflow_task
117122
self.task_batch_index = self.workflow_task.task_batch_index
118123
self.task_tag: str = self.workflow_task.task_tag
119124
self.task_id: str = self.workflow_task.task_id
120125
self.episode_uuid = self.workflow_task.episode_uuid
121126

122-
self.config = config
127+
# tokenizer
123128
self.tokenizer = tokenizer
129+
self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n")
130+
self._im_start_token_id = tokenizer.encode("<|im_start|>")[0]
131+
132+
# config
133+
self.config = config
124134
self.saved_timelines: List[List[ExtendedMessage]] = []
125135
self.current_context_status = ""
136+
137+
# length control
126138
max_response_length = self.config.ajet.rollout.max_response_length_in_one_turn
127139
max_model_len: int = self.config.ajet.rollout.max_model_len
128140
self.max_seq_length: int = max_model_len - max_response_length
129-
self.blackout_token_combo = tokenizer.encode("<|im_start|>assistant\n")
130-
self._im_start_token_id = tokenizer.encode("<|im_start|>")[0]
131-
self.generated_token_cnt = 0
132-
self.terminal_rewards_dict = {}
133-
self.discarded = False
134-
self.is_terminated = False
135-
self.reward_structure: Union[Reward, None] = None
136-
self.context_time_cost = 0
141+
142+
self.generation_prompt_token = None
143+
self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None # Initialize workflow_metadata to store tool statistics
144+
145+
# meta data attributes
137146
self.tag = ""
147+
self.round_cnt = 0
148+
self.generated_token_cnt = 0
138149
self.current_batch_success_rate: float = float("-inf")
139150
self.current_batch_reward: float = float("-inf")
151+
152+
# reward and madness detection
153+
self.reward_structure: Union[Reward, None] = None
140154
self.already_mad_flag: bool = False
141-
self.round_cnt = 0
142-
self.generation_prompt_token = None
143-
self.log_metrics: Optional[Dict[str, Union[float, List[float], Dict[str, Any]]]] = None # Initialize workflow_metadata to store tool statistics
144155

145156
assert (
146157
self.config.ajet.data.max_prompt_length
@@ -149,13 +160,13 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs):
149160
)
150161

151162
def reset(self):
163+
# disable read only mode
164+
self._read_only = False
165+
self._discarded = False
166+
152167
self.saved_timelines: List[List[ExtendedMessage]] = []
153168
self.current_context_status = ""
154-
self.terminal_rewards_dict = {}
155-
self.discarded = False
156-
self.is_terminated = False
157169
self.reward_structure: Union[Reward, None] = None
158-
self.context_time_cost = 0
159170
self.tag = ""
160171
self.current_batch_success_rate: float = float("-inf")
161172
self.current_batch_reward: float = float("-inf")

ajet/context_tracker/basic_tracker.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,6 @@ class BaseContextTracker(BaseTracker):
2424
full_context (List[ExtendedMessage]): List of all messages in the conversation
2525
current_context_status (str): Current status of the context
2626
max_seq_length (int): Maximum sequence length for the context window
27-
terminal_rewards_dict (dict): Dictionary storing terminal rewards
2827
"""
2928

3029
def __init__(self, config, tokenizer, **kwargs):

ajet/context_tracker/multiagent_tracking.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -216,7 +216,12 @@ def step_track(
216216
timeline_uuid: str = "",
217217
):
218218
assert timeline_uuid in self.timeline_cache, "Timeline UUID not found in cache. Please ensure `step_prepare` is called before `step_track`."
219-
timeline = self.timeline_cache.get(timeline_uuid, [])
219+
220+
# round ++
221+
self.round_cnt += 1
222+
223+
# get timeline from cache
224+
timeline = self.timeline_cache.pop(timeline_uuid, [])
220225
if not self.already_mad_flag:
221226
if (
222227
compute_string_madness(
@@ -291,6 +296,11 @@ def save_llm_interaction_timeline(self, tools, llm_ext_msg, timeline):
291296
for i in range(1, len(timeline)):
292297
assert not timeline[i].first_message
293298

299+
# no longer write anything
300+
if self._read_only:
301+
logger.exception("Timeline is in read-only mode, should not save new timeline. Please report a github issue if you see this error.")
302+
return
303+
294304
# save to self.saved_timelines
295305
self.saved_timelines += [copy.deepcopy(timeline)]
296306

@@ -556,6 +566,8 @@ def generate_log(self, task_id=None, global_step="NA"):
556566
def group_merge(self) -> List[List[ExtendedMessage]]:
557567
timeline_merging_policy: TimelineMergingPolicyConfig = self.config.ajet.context_tracker.timeline_merging_policy
558568
self.saved_timelines = merge_tracker_timelines(self.saved_timelines, timeline_merging_policy)
569+
self._read_only = True
570+
559571
return self.saved_timelines
560572

561573

ajet/task_rollout/native_parallel_worker.py

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -255,7 +255,7 @@ def rollout_dynamic( # noqa: C901
255255
completed_task_futures = [f for f in task_future_array if f.done()]
256256
completed_results = [f.result() for f in completed_task_futures]
257257
completed_results = [
258-
tracker for tracker in completed_results if not tracker.discarded
258+
tracker for tracker in completed_results if not tracker._discarded
259259
]
260260
reward = [
261261
tracker.reward_structure.performance_reward for tracker in completed_results
@@ -306,7 +306,7 @@ def rollout_dynamic( # noqa: C901
306306
)
307307
time.sleep(5)
308308

309-
# We have enough number of samples, but we need to wait for all threads to finish, including discarded threads
309+
# We have enough number of samples, but we need to wait for all threads to finish, including discarded threads
310310
tic = -1
311311
while any(f.running() for task_future_array in futures for f in task_future_array):
312312
tic += 1
@@ -325,7 +325,7 @@ def rollout_dynamic( # noqa: C901
325325
completed_task_futures = [f for f in task_future_array if f.done()]
326326
completed_results = [f.result() for f in completed_task_futures]
327327
completed_results = [
328-
tracker for tracker in completed_results if not tracker.discarded
328+
tracker for tracker in completed_results if not tracker._discarded
329329
]
330330
task_cmd_reward_array = [
331331
tracker.reward_structure.performance_reward for tracker in completed_results
@@ -409,7 +409,7 @@ def rollout_dynamic( # noqa: C901
409409
completed_task_futures = [f for f in task_future_array if f.done()]
410410
completed_results = [f.result() for f in completed_task_futures]
411411
completed_results = [
412-
tracker for tracker in completed_results if not tracker.discarded
412+
tracker for tracker in completed_results if not tracker._discarded
413413
]
414414
# in-group success rate and reward
415415
task_cmd_reward_array = [
@@ -583,6 +583,19 @@ def stop_all_threads_hard():
583583
for ct_list in completed_task_id_map_ct.values():
584584
tracker_array.extend(ct_list)
585585

586+
587+
# TODO: support multi-step reward
588+
task_success_rate = np.mean(
589+
[tracker.reward_structure.success_rate for tracker in tracker_array]
590+
)
591+
task_scalar_reward = np.mean(
592+
[tracker.reward_structure.final_scalar_reward for tracker in tracker_array]
593+
)
594+
595+
for tracker in tracker_array:
596+
tracker.current_batch_success_rate = float(task_success_rate)
597+
tracker.current_batch_reward = float(task_scalar_reward)
598+
586599
# return all trackers
587600
return tracker_array
588601

ajet/task_rollout/single_worker.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -160,14 +160,18 @@ def rollout_env_worker_loop(
160160
**kwargs,
161161
):
162162
try:
163+
163164
cnt = 1
165+
164166
while True:
165167

166-
if observation_window["stop"][task_thread_index]:
167-
print('rollout_env_worker_loop received stop signal, exiting...')
168+
if observation_window["stop"][task_thread_index]: # since we use multi-threading, the best way to communicate with main thread is through shared memory.
168169
return
169170

170-
observation_window["info"][task_thread_index] = str(cnt)
171+
observation_window["info"][task_thread_index] = str(cnt) # observe how many iterations have been done in the loop
172+
173+
# Let's begin working on the task, the result `tracker` will contain everything: reward, llm calls, conversation history, etc.
174+
# Later we will gather all trackers and do post-processing, generating samples for VeRL.
171175
tracker = self.rollout_env_worker(
172176
task=task,
173177
task_batch_index=task_batch_index,
@@ -185,7 +189,9 @@ def rollout_env_worker_loop(
185189
completed_task_id_map_ct[tracker.task_id] = [tracker]
186190
else:
187191
completed_task_id_map_ct[tracker.task_id] += [tracker]
192+
188193
cnt += 1
194+
189195
if observation_window["stop"][task_thread_index]:
190196
return
191197
else:

ajet/task_runner/swarm_runner.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -130,13 +130,15 @@ def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
130130
print(f'Exiting workflow worker due to interrupt signal for episode {workflow_task.episode_uuid}.')
131131
raise SwarmReceiveAbortException(f"Episode {workflow_task.episode_uuid} aborted due to interrupt signal.")
132132

133+
# context tracker will trace and gather everything we need for training
133134
context_tracker = MultiAgentContextTracker(
134135
llm_inference_fn=self.llm_inference_fn,
135136
tokenizer=self.tokenizer,
136137
config=self.config,
137138
workflow_task = workflow_task,
138139
**hooks,
139140
)
141+
# tuner will handle the communication and provide `baseurl_apikey`
140142
tuner = AjetTuner(
141143
context_tracker=context_tracker,
142144
llm_inference_fn=self.llm_inference_fn,

0 commit comments

Comments
 (0)