
Commit 7ca0434

[Feat] Add raw reward logging in DataFlow (#1633)
Add raw reward accumulation and logging in RawDataFlow

- Introduced `_raw_reward_sum` and `_raw_reward_count` to track raw rewards during data processing.
- Implemented logic in `worker_task` to accumulate raw rewards for groups that `determine_group_state` marks as completed, before the post-processor filters samples out.
- Updated logging to include the average raw reward and sample count in the metrics output.
1 parent d3f185d commit 7ca0434

1 file changed

Lines changed: 18 additions & 0 deletions

File tree

xtuner/v1/ray/dataflow/flow.py

@@ -152,6 +152,8 @@ def __init__(
         self.skipped_sample_count = 0
         self.failed_sample_count = 0
         self.filtered_samples_count = 0
+        self._raw_reward_sum = 0.0
+        self._raw_reward_count = 0
         self.tb_metrics: Dict[str, Any] = {}
         self.target_batch_size = self.config.global_batch_size
         rollout_info = ray.get(self.env_controller.get_rollout_info.remote())  # type: ignore[attr-defined]
@@ -190,6 +192,8 @@ def _reset_internal_states(
         self.skipped_sample_count = 0
         self.failed_sample_count = 0
         self.filtered_samples_count = 0
+        self._raw_reward_sum = 0.0
+        self._raw_reward_count = 0
         self.tb_metrics = {}
         if global_batch_size and global_batch_size > 0:
             self.target_batch_size = global_batch_size
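Both `__init__` and `_reset_internal_states` zero the same two fields, so the average reported later covers only samples collected since the last reset. A minimal sketch of that accumulate/reset contract, using a hypothetical class name that is not part of the commit:

from typing import Optional


class RawRewardAccumulator:
    """Illustrative only: mirrors the sum/count fields added in flow.py."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        # Matches what both __init__ and _reset_internal_states do.
        self._raw_reward_sum = 0.0
        self._raw_reward_count = 0

    def add(self, score: float) -> None:
        self._raw_reward_sum += score
        self._raw_reward_count += 1

    def average(self) -> Optional[float]:
        if self._raw_reward_count == 0:
            return None
        return self._raw_reward_sum / self._raw_reward_count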
@@ -254,6 +258,15 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
         group_state = determine_group_state(group_data_items)
         self.logger.debug(f"Determined replay state for {action_id}: {group_state}")
         if group_state == RolloutState.COMPLETED:
+            # Accumulate raw rewards before post_processor filters samples out.
+            for item in group_data_items:
+                reward_data = getattr(item.env, "judger", None)
+                if reward_data is not None:
+                    reward_dict = reward_data.reward if hasattr(reward_data, "reward") else reward_data
+                    score = reward_dict.get("score") if isinstance(reward_dict, dict) else None
+                    if score is not None:
+                        self._raw_reward_sum += score
+                        self._raw_reward_count += 1
             if not self.sample_from_expired_storage:
                 group_data_items = self.replay_buffer.post_processor(group_data_items)  # type: ignore[attr-defined]
             if len(group_data_items) > 0:
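The accumulation loop is deliberately defensive about the judger payload's shape: `item.env.judger` may be missing, may be an object exposing a `.reward` dict, or may itself be a dict. A self-contained sketch of the same extraction, factored into a hypothetical helper (`extract_raw_score` is illustrative, not part of the diff):

from typing import Any, Optional


def extract_raw_score(env: Any) -> Optional[float]:
    reward_data = getattr(env, "judger", None)
    if reward_data is None:
        return None
    # Unwrap an object carrying a `.reward` attribute; otherwise treat the
    # value itself as the reward payload.
    reward_dict = reward_data.reward if hasattr(reward_data, "reward") else reward_data
    if not isinstance(reward_dict, dict):
        return None
    return reward_dict.get("score")


# Example with a plain-dict judger payload (hypothetical shape):
class _Env:
    judger = {"score": 0.75}

assert extract_raw_score(_Env()) == 0.75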
@@ -490,6 +503,11 @@ def logging_replaybuffer_state(self, logging_msg: Optional[str] = None):
         logging_msg += f"skipped_samples_count: {self.skipped_sample_count}, "
         logging_msg += f"failed_samples_count: {self.failed_sample_count}, "
         logging_msg += f"filtered_samples_count: {self.filtered_samples_count}, "
+        if self._raw_reward_count > 0:
+            avg_raw_reward = self._raw_reward_sum / self._raw_reward_count
+            logging_msg += f"avg_raw_reward: {avg_raw_reward:.6f} (n={self._raw_reward_count}), "
+            self.tb_metrics["reward/avg_raw_reward"] = avg_raw_reward
+            self.tb_metrics["reward/raw_reward_count"] = self._raw_reward_count
         self.logger.info(logging_msg)

     def get_replaybuffer_status(self):
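The `if self._raw_reward_count > 0` guard avoids a division by zero when no raw rewards were accumulated in the window; in that case neither the log field nor the TensorBoard metrics are emitted. A minimal standalone illustration of the formatting behaviour (values are made up):

_raw_reward_sum, _raw_reward_count = 2.5, 4

logging_msg = "filtered_samples_count: 0, "
if _raw_reward_count > 0:
    avg_raw_reward = _raw_reward_sum / _raw_reward_count
    logging_msg += f"avg_raw_reward: {avg_raw_reward:.6f} (n={_raw_reward_count}), "
print(logging_msg)
# filtered_samples_count: 0, avg_raw_reward: 0.625000 (n=4),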
