@@ -262,30 +262,28 @@ def tokenize_steps(
262262
263263 # check reward structure
264264 self .reward_structure : Reward # type: ignore
265- assert (
266- self .reward_structure .step_reward_arr is not None
267- ), "must call `process_reward` before tokenize_steps"
268- assert len (self .reward_structure .step_reward_arr ) == total_steps
265+ assert self .reward_structure .step_reward_arr is not None , "must call `process_reward` before tokenize_steps"
266+ assert len (self .reward_structure .step_reward_arr ) == total_steps , f"reward step count { len (self .reward_structure .step_reward_arr )} != total_steps { total_steps } "
269267
270268 # mapping
271269 input_ids = []
272270 input_logprobs = []
273271 attention_mask = []
274272 loss_mask = []
275- split_prompt_reponse_index = - 1
273+ split_prompt_response_index = - 1
276274 split_point_message_left_index = - 1
277275 input_ids_len = []
278276
279277 # cat all messages
280278 for i , ext_msg in enumerate (ext_steps ):
281279 # find split index, this have to be done before input_ids += ext_msg.token_arr
282- if (split_prompt_reponse_index == - 1 ) and (ext_msg .need_training ):
283- split_prompt_reponse_index = len (input_ids )
280+ if (split_prompt_response_index == - 1 ) and (ext_msg .need_training ):
281+ split_prompt_response_index = len (input_ids )
284282 split_point_message_left_index = i - 1
285283 assert (
286284 split_point_message_left_index >= 0
287285 ), "There should be at least one message before the first training message"
288- assert split_prompt_reponse_index == input_ids_len [split_point_message_left_index ]
286+ assert split_prompt_response_index == input_ids_len [split_point_message_left_index ]
289287 assert (
290288 ext_msg .author == "llm"
291289 ), "The first message after initialization should be from LLM, not from env or user"
@@ -304,37 +302,37 @@ def tokenize_steps(
304302 # move the split index forward
305303 MAX_FORWARD_STEPS = 100
306304 for i in range (MAX_FORWARD_STEPS ):
307- if loss_mask [split_prompt_reponse_index ] == 0 :
308- split_prompt_reponse_index += 1
305+ if loss_mask [split_prompt_response_index ] == 0 :
306+ split_prompt_response_index += 1
309307 else :
310308 break
311309
312310 # no matter what, the split index should not exceed max prompt length
313311 # make sure that the prompt length does not exceed `config.ajet.data.max_prompt_length`
314- if split_prompt_reponse_index > self .config .ajet .data .max_prompt_length :
315- split_prompt_reponse_index = self .config .ajet .data .max_prompt_length
312+ if split_prompt_response_index > self .config .ajet .data .max_prompt_length :
313+ split_prompt_response_index = self .config .ajet .data .max_prompt_length
316314
317315 # check
318316 assert len (ext_steps ) == len (
319317 input_ids_len
320318 ), "length of ext_steps and input_ids_len should be equal"
321319 assert (
322- split_prompt_reponse_index != - 1
323- ), "split_prompt_reponse_index should not be -1, at least one message should be in the context"
320+ split_prompt_response_index != - 1
321+ ), "split_prompt_response_index should not be -1, at least one message should be in the context"
324322 position_ids = compute_position_id_with_mask (torch .tensor (attention_mask )).tolist ()
325323
326324 # sperate prompt and response
327- prompt_ids = input_ids [:split_prompt_reponse_index ]
328- prompt_attention_mask = attention_mask [:split_prompt_reponse_index ]
329- prompt_position_ids = position_ids [:split_prompt_reponse_index ]
330- prompt_loss_mask = loss_mask [:split_prompt_reponse_index ]
331- prompt_logprobs = input_logprobs [:split_prompt_reponse_index ]
332-
333- response_ids = input_ids [split_prompt_reponse_index :]
334- response_attention_mask = attention_mask [split_prompt_reponse_index :]
335- response_position_ids = position_ids [split_prompt_reponse_index :]
336- response_loss_mask = loss_mask [split_prompt_reponse_index :]
337- response_logprobs = input_logprobs [split_prompt_reponse_index :]
325+ prompt_ids = input_ids [:split_prompt_response_index ]
326+ prompt_attention_mask = attention_mask [:split_prompt_response_index ]
327+ prompt_position_ids = position_ids [:split_prompt_response_index ]
328+ prompt_loss_mask = loss_mask [:split_prompt_response_index ]
329+ prompt_logprobs = input_logprobs [:split_prompt_response_index ]
330+
331+ response_ids = input_ids [split_prompt_response_index :]
332+ response_attention_mask = attention_mask [split_prompt_response_index :]
333+ response_position_ids = position_ids [split_prompt_response_index :]
334+ response_loss_mask = loss_mask [split_prompt_response_index :]
335+ response_logprobs = input_logprobs [split_prompt_response_index :]
338336
339337 tracker_tokenized = {}
340338 tracker_tokenized ["input_ids" ] = input_ids
0 commit comments