Skip to content

Commit 4cb513b

Browse files
committed
enhance error logging during tracker.tokenize() for better debugging
1 parent b15983a commit 4cb513b

File tree

2 files changed

+24
-25
lines changed

2 files changed

+24
-25
lines changed

ajet/context_tracker/basic_tracker.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -262,30 +262,28 @@ def tokenize_steps(
262262

263263
# check reward structure
264264
self.reward_structure: Reward # type: ignore
265-
assert (
266-
self.reward_structure.step_reward_arr is not None
267-
), "must call `process_reward` before tokenize_steps"
268-
assert len(self.reward_structure.step_reward_arr) == total_steps
265+
assert self.reward_structure.step_reward_arr is not None, "must call `process_reward` before tokenize_steps"
266+
assert len(self.reward_structure.step_reward_arr) == total_steps, f"reward step count {len(self.reward_structure.step_reward_arr)} != total_steps {total_steps}"
269267

270268
# mapping
271269
input_ids = []
272270
input_logprobs = []
273271
attention_mask = []
274272
loss_mask = []
275-
split_prompt_reponse_index = -1
273+
split_prompt_response_index = -1
276274
split_point_message_left_index = -1
277275
input_ids_len = []
278276

279277
# cat all messages
280278
for i, ext_msg in enumerate(ext_steps):
281279
# find split index, this has to be done before input_ids += ext_msg.token_arr
282-
if (split_prompt_reponse_index == -1) and (ext_msg.need_training):
283-
split_prompt_reponse_index = len(input_ids)
280+
if (split_prompt_response_index == -1) and (ext_msg.need_training):
281+
split_prompt_response_index = len(input_ids)
284282
split_point_message_left_index = i - 1
285283
assert (
286284
split_point_message_left_index >= 0
287285
), "There should be at least one message before the first training message"
288-
assert split_prompt_reponse_index == input_ids_len[split_point_message_left_index]
286+
assert split_prompt_response_index == input_ids_len[split_point_message_left_index]
289287
assert (
290288
ext_msg.author == "llm"
291289
), "The first message after initialization should be from LLM, not from env or user"
@@ -304,37 +302,37 @@ def tokenize_steps(
304302
# move the split index forward
305303
MAX_FORWARD_STEPS = 100
306304
for i in range(MAX_FORWARD_STEPS):
307-
if loss_mask[split_prompt_reponse_index] == 0:
308-
split_prompt_reponse_index += 1
305+
if loss_mask[split_prompt_response_index] == 0:
306+
split_prompt_response_index += 1
309307
else:
310308
break
311309

312310
# no matter what, the split index should not exceed max prompt length
313311
# make sure that the prompt length does not exceed `config.ajet.data.max_prompt_length`
314-
if split_prompt_reponse_index > self.config.ajet.data.max_prompt_length:
315-
split_prompt_reponse_index = self.config.ajet.data.max_prompt_length
312+
if split_prompt_response_index > self.config.ajet.data.max_prompt_length:
313+
split_prompt_response_index = self.config.ajet.data.max_prompt_length
316314

317315
# check
318316
assert len(ext_steps) == len(
319317
input_ids_len
320318
), "length of ext_steps and input_ids_len should be equal"
321319
assert (
322-
split_prompt_reponse_index != -1
323-
), "split_prompt_reponse_index should not be -1, at least one message should be in the context"
320+
split_prompt_response_index != -1
321+
), "split_prompt_response_index should not be -1, at least one message should be in the context"
324322
position_ids = compute_position_id_with_mask(torch.tensor(attention_mask)).tolist()
325323

326324
# separate prompt and response
327-
prompt_ids = input_ids[:split_prompt_reponse_index]
328-
prompt_attention_mask = attention_mask[:split_prompt_reponse_index]
329-
prompt_position_ids = position_ids[:split_prompt_reponse_index]
330-
prompt_loss_mask = loss_mask[:split_prompt_reponse_index]
331-
prompt_logprobs = input_logprobs[:split_prompt_reponse_index]
332-
333-
response_ids = input_ids[split_prompt_reponse_index:]
334-
response_attention_mask = attention_mask[split_prompt_reponse_index:]
335-
response_position_ids = position_ids[split_prompt_reponse_index:]
336-
response_loss_mask = loss_mask[split_prompt_reponse_index:]
337-
response_logprobs = input_logprobs[split_prompt_reponse_index:]
325+
prompt_ids = input_ids[:split_prompt_response_index]
326+
prompt_attention_mask = attention_mask[:split_prompt_response_index]
327+
prompt_position_ids = position_ids[:split_prompt_response_index]
328+
prompt_loss_mask = loss_mask[:split_prompt_response_index]
329+
prompt_logprobs = input_logprobs[:split_prompt_response_index]
330+
331+
response_ids = input_ids[split_prompt_response_index:]
332+
response_attention_mask = attention_mask[split_prompt_response_index:]
333+
response_position_ids = position_ids[split_prompt_response_index:]
334+
response_loss_mask = loss_mask[split_prompt_response_index:]
335+
response_logprobs = input_logprobs[split_prompt_response_index:]
338336

339337
tracker_tokenized = {}
340338
tracker_tokenized["input_ids"] = input_ids

ajet/task_rollout/native_parallel_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def trajectories_to_samples(self, tracker_array: List[BaseContextTracker]) -> Li
606606
except Exception as e:
607607
raise e
608608
finally:
609+
logger.bind(exception=True).exception("Error during tracker.tokenize()") # for debugging
609610
tracker.generate_log(global_step=self.current_global_steps)
610611
if os.environ.get("BEST_LOGGER_PATH", None) and os.environ.get(
611612
"AJET_DEBUG", None

0 commit comments

Comments
 (0)