@@ -152,6 +152,8 @@ def __init__(
152152 self .skipped_sample_count = 0
153153 self .failed_sample_count = 0
154154 self .filtered_samples_count = 0
155+ self ._raw_reward_sum = 0.0
156+ self ._raw_reward_count = 0
155157 self .tb_metrics : Dict [str , Any ] = {}
156158 self .target_batch_size = self .config .global_batch_size
157159 rollout_info = ray .get (self .env_controller .get_rollout_info .remote ()) # type: ignore[attr-defined]
@@ -190,6 +192,8 @@ def _reset_internal_states(
190192 self .skipped_sample_count = 0
191193 self .failed_sample_count = 0
192194 self .filtered_samples_count = 0
195+ self ._raw_reward_sum = 0.0
196+ self ._raw_reward_count = 0
193197 self .tb_metrics = {}
194198 if global_batch_size and global_batch_size > 0 :
195199 self .target_batch_size = global_batch_size
@@ -254,6 +258,15 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
254258 group_state = determine_group_state (group_data_items )
255259 self .logger .debug (f"Determined replay state for { action_id } : { group_state } " )
256260 if group_state == RolloutState .COMPLETED :
261+ # Accumulate raw rewards before post_processor filters samples out.
262+ for item in group_data_items :
263+ reward_data = getattr (item .env , "judger" , None )
264+ if reward_data is not None :
265+ reward_dict = reward_data .reward if hasattr (reward_data , "reward" ) else reward_data
266+ score = reward_dict .get ("score" ) if isinstance (reward_dict , dict ) else None
267+ if score is not None :
268+ self ._raw_reward_sum += score
269+ self ._raw_reward_count += 1
257270 if not self .sample_from_expired_storage :
258271 group_data_items = self .replay_buffer .post_processor (group_data_items ) # type: ignore[attr-defined]
259272 if len (group_data_items ) > 0 :
@@ -490,6 +503,11 @@ def logging_replaybuffer_state(self, logging_msg: Optional[str] = None):
490503 logging_msg += f"skipped_samples_count: { self .skipped_sample_count } , "
491504 logging_msg += f"failed_samples_count: { self .failed_sample_count } , "
492505 logging_msg += f"filtered_samples_count: { self .filtered_samples_count } , "
506+ if self ._raw_reward_count > 0 :
507+ avg_raw_reward = self ._raw_reward_sum / self ._raw_reward_count
508+ logging_msg += f"avg_raw_reward: { avg_raw_reward :.6f} (n={ self ._raw_reward_count } ), "
509+ self .tb_metrics ["reward/avg_raw_reward" ] = avg_raw_reward
510+ self .tb_metrics ["reward/raw_reward_count" ] = self ._raw_reward_count
493511 self .logger .info (logging_msg )
494512
495513 def get_replaybuffer_status (self ):
0 commit comments