@@ -113,34 +113,45 @@ def replace_token_ids(
113113class BaseTracker (object ):
114114 def __init__ (self , config , tokenizer , workflow_task : WorkflowTask , ** kwargs ):
115115
116+ # disable read only mode
117+ self ._read_only = False
118+ self ._discarded = False
119+
120+ # task related info
116121 self .workflow_task = workflow_task
117122 self .task_batch_index = self .workflow_task .task_batch_index
118123 self .task_tag : str = self .workflow_task .task_tag
119124 self .task_id : str = self .workflow_task .task_id
120125 self .episode_uuid = self .workflow_task .episode_uuid
121126
122- self . config = config
127+ # tokenizer
123128 self .tokenizer = tokenizer
129+ self .blackout_token_combo = tokenizer .encode ("<|im_start|>assistant\n " )
130+ self ._im_start_token_id = tokenizer .encode ("<|im_start|>" )[0 ]
131+
132+ # config
133+ self .config = config
124134 self .saved_timelines : List [List [ExtendedMessage ]] = []
125135 self .current_context_status = ""
136+
137+ # length control
126138 max_response_length = self .config .ajet .rollout .max_response_length_in_one_turn
127139 max_model_len : int = self .config .ajet .rollout .max_model_len
128140 self .max_seq_length : int = max_model_len - max_response_length
129- self .blackout_token_combo = tokenizer .encode ("<|im_start|>assistant\n " )
130- self ._im_start_token_id = tokenizer .encode ("<|im_start|>" )[0 ]
131- self .generated_token_cnt = 0
132- self .terminal_rewards_dict = {}
133- self .discarded = False
134- self .is_terminated = False
135- self .reward_structure : Union [Reward , None ] = None
136- self .context_time_cost = 0
141+
142+ self .generation_prompt_token = None
143+ self .log_metrics : Optional [Dict [str , Union [float , List [float ], Dict [str , Any ]]]] = None # Initialize workflow_metadata to store tool statistics
144+
145+ # meta data attributes
137146 self .tag = ""
147+ self .round_cnt = 0
148+ self .generated_token_cnt = 0
138149 self .current_batch_success_rate : float = float ("-inf" )
139150 self .current_batch_reward : float = float ("-inf" )
151+
152+ # reward and madness detection
153+ self .reward_structure : Union [Reward , None ] = None
140154 self .already_mad_flag : bool = False
141- self .round_cnt = 0
142- self .generation_prompt_token = None
143- self .log_metrics : Optional [Dict [str , Union [float , List [float ], Dict [str , Any ]]]] = None # Initialize workflow_metadata to store tool statistics
144155
145156 assert (
146157 self .config .ajet .data .max_prompt_length
@@ -149,13 +160,13 @@ def __init__(self, config, tokenizer, workflow_task: WorkflowTask, **kwargs):
149160 )
150161
151162 def reset (self ):
163+ # disable read only mode
164+ self ._read_only = False
165+ self ._discarded = False
166+
152167 self .saved_timelines : List [List [ExtendedMessage ]] = []
153168 self .current_context_status = ""
154- self .terminal_rewards_dict = {}
155- self .discarded = False
156- self .is_terminated = False
157169 self .reward_structure : Union [Reward , None ] = None
158- self .context_time_cost = 0
159170 self .tag = ""
160171 self .current_batch_success_rate : float = float ("-inf" )
161172 self .current_batch_reward : float = float ("-inf" )
0 commit comments