@@ -63,6 +63,49 @@ class ReplayMeta:
6363 extra_info : Dict [str , Any ] = field (default_factory = dict )
6464
6565
def summarize_group_payload(grouped_dataitem: List[RLDataFlowItem]) -> Dict[str, Any]:
    """Build a lightweight, JSON-serializable summary of one rollout group.

    The summary is stored in ``ReplayMeta.extra_info`` so that storage
    statistics can later be computed without dereferencing the (potentially
    large) object refs that hold the actual payloads.

    Args:
        grouped_dataitem: All dataflow items belonging to one rollout group.

    Returns:
        Aggregated counters (token/char totals, versioned-segment counts,
        expert/judger payload counts) plus a ``payload_mode`` marker that is
        ``"full"`` while rollout outputs are still attached.
    """
    summary: Dict[str, Any] = {
        "payload_mode": "full",
        "observation_count": len(grouped_dataitem),
        "response_tokens": 0,
        "response_chars": 0,
        "versioned_segments": 0,
        "versioned_tokens": 0,
        "routed_expert_payloads": 0,
        "judged_observations": 0,
        "has_multimodal_prompt": False,
    }
    if not grouped_dataitem:
        return summary

    # Only the first item is inspected for multimodal info; the prompt is
    # shared by the whole group. bool() alone suffices for sized containers
    # and no longer raises on a truthy object without __len__.
    first_data = grouped_dataitem[0].data
    summary["has_multimodal_prompt"] = bool(
        getattr(first_data, "multimodal_train_info", None)
    )

    for item in grouped_dataitem:
        rollout = item.env.rollout
        judger = item.env.judger
        response_ids = rollout.response_ids or []
        response_text = rollout.response or ""
        versioned_response_ids = rollout.versioned_response_ids or []
        versioned_num_return_tokens = rollout.versioned_num_return_tokens or []

        summary["response_tokens"] += len(response_ids)
        summary["response_chars"] += len(response_text)
        summary["versioned_segments"] += len(versioned_response_ids)
        if versioned_num_return_tokens:
            summary["versioned_tokens"] += sum(versioned_num_return_tokens)
        else:
            # Fall back to counting the ids directly when per-segment token
            # counts were not recorded.
            summary["versioned_tokens"] += sum(len(ids) for ids in versioned_response_ids)
        # Guard with `or {}` so a None extra_info/reward does not crash the
        # summary — mirrors the `or []` handling of the rollout fields above.
        if (rollout.extra_info or {}).get("routed_experts", None) is not None:
            summary["routed_expert_payloads"] += 1
        if (
            judger.uid is not None
            or (judger.reward or {}).get("score", 0.0) != 0.0
            or len(judger.extra_info or {}) > 0
        ):
            summary["judged_observations"] += 1

    return summary
107+
108+
66109def determine_group_state (group_data_items : List [RLDataFlowItem ]) -> RolloutState :
67110 """Determines the processing strategy for a group of rollout samples based
68111 on their state."""
@@ -113,7 +156,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re
113156 observation_refs = observation_refs ,
114157 state = group_state ,
115158 version = group_version ,
116- extra_info = {} ,
159+ extra_info = summarize_group_payload ( grouped_dataitem ) ,
117160 )
118161 return replay_meta
119162
@@ -323,6 +366,87 @@ def __init__(self, replay_buffer_cfg):
323366 self .sample_from_aborted_count = 0
324367 self .sample_from_expired_count = 0
325368
369+ def _free_replay_meta_refs (self , replay_meta : ReplayMeta , include_action_ref : bool = True ):
370+ refs = []
371+ if include_action_ref and replay_meta .action_ref is not None :
372+ refs .append (replay_meta .action_ref )
373+ refs .extend ([ref for ref in replay_meta .observation_refs if ref is not None ])
374+ if refs :
375+ ray .internal .free (refs , local_only = False )
376+
377+ def _update_replay_meta_state (self , replay_meta : ReplayMeta , new_state : RolloutState ):
378+ for observation_id in replay_meta .observation_ids :
379+ old_state = self ._observations2states .get (observation_id )
380+ if old_state and observation_id in self ._states .get (old_state , []):
381+ self ._states [old_state ].remove (observation_id )
382+ self ._observations2states [observation_id ] = new_state
383+ if observation_id not in self ._states [new_state ]:
384+ self ._states [new_state ].append (observation_id )
385+ replay_meta .state = new_state
386+
387+ def _strip_rollout_payload_for_rerun (self , replay_meta : ReplayMeta , new_state : RolloutState ):
388+ """Keep prompt refs only and drop rollout outputs that will not be reused."""
389+ old_obs_refs = [ref for ref in replay_meta .observation_refs if ref is not None ]
390+ if old_obs_refs :
391+ ray .internal .free (old_obs_refs , local_only = False )
392+ replay_meta .observation_refs = [ray .put (RLEnvDataItem ()) for _ in replay_meta .observation_ids ]
393+ replay_meta .extra_info .update (
394+ {
395+ "payload_mode" : "prompt_only" ,
396+ "response_tokens" : 0 ,
397+ "response_chars" : 0 ,
398+ "versioned_segments" : 0 ,
399+ "versioned_tokens" : 0 ,
400+ "routed_expert_payloads" : 0 ,
401+ "judged_observations" : 0 ,
402+ }
403+ )
404+ self ._update_replay_meta_state (replay_meta , new_state )
405+
406+ def get_storage_stats (self ) -> Dict [str , float ]:
407+ stats : Dict [str , float ] = {
408+ "tracked_actions_count" : float (len (self ._actions )),
409+ "tracked_roots_count" : float (len (self ._root2actions )),
410+ "tracked_observations_count" : float (len (self ._observations )),
411+ "completed_actions_count" : float (sum (len (bucket ) for bucket in self ._completed_actions .values ())),
412+ "aborted_actions_count" : float (sum (len (bucket ) for bucket in self ._aborted_actions .values ())),
413+ "expired_actions_count" : float (len (self ._expired_actions )),
414+ "completed_versions_count" : float (len (self ._completed_actions )),
415+ "aborted_versions_count" : float (len (self ._aborted_actions )),
416+ "payload_full_actions_count" : 0.0 ,
417+ "payload_prompt_only_actions_count" : 0.0 ,
418+ "payload_full_observations_count" : 0.0 ,
419+ "payload_prompt_only_observations_count" : 0.0 ,
420+ "stored_response_tokens" : 0.0 ,
421+ "stored_response_chars" : 0.0 ,
422+ "stored_versioned_segments" : 0.0 ,
423+ "stored_versioned_tokens" : 0.0 ,
424+ "stored_routed_expert_payloads" : 0.0 ,
425+ "stored_judged_observations" : 0.0 ,
426+ "multimodal_actions_count" : 0.0 ,
427+ }
428+
429+ for replay_meta in self ._actions .values ():
430+ summary = replay_meta .extra_info
431+ observation_count = float (summary .get ("observation_count" , len (replay_meta .observation_ids )))
432+ if summary .get ("payload_mode" , "full" ) == "prompt_only" :
433+ stats ["payload_prompt_only_actions_count" ] += 1.0
434+ stats ["payload_prompt_only_observations_count" ] += observation_count
435+ else :
436+ stats ["payload_full_actions_count" ] += 1.0
437+ stats ["payload_full_observations_count" ] += observation_count
438+
439+ stats ["stored_response_tokens" ] += float (summary .get ("response_tokens" , 0 ))
440+ stats ["stored_response_chars" ] += float (summary .get ("response_chars" , 0 ))
441+ stats ["stored_versioned_segments" ] += float (summary .get ("versioned_segments" , 0 ))
442+ stats ["stored_versioned_tokens" ] += float (summary .get ("versioned_tokens" , 0 ))
443+ stats ["stored_routed_expert_payloads" ] += float (summary .get ("routed_expert_payloads" , 0 ))
444+ stats ["stored_judged_observations" ] += float (summary .get ("judged_observations" , 0 ))
445+ if summary .get ("has_multimodal_prompt" , False ):
446+ stats ["multimodal_actions_count" ] += 1.0
447+
448+ return stats
449+
326450 def add (self , grouped_dataitem : List [RLDataFlowItem ]):
327451 """Adds a group of data items to the storage.
328452
@@ -426,6 +550,8 @@ def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]:
426550 return []
427551
428552 def clear (self ):
553+ for replay_meta in self ._actions .values ():
554+ self ._free_replay_meta_refs (replay_meta )
429555 attrs_to_clear = [
430556 "_aborted_actions" ,
431557 "_completed_actions" ,
@@ -699,6 +825,10 @@ def _check_completed_samples_expired(self):
699825
700826 for version in expired_versions :
701827 bucket = self ._completed_actions .pop (version )
828+ for action_id in bucket :
829+ replay_meta = self ._actions .get (action_id )
830+ if replay_meta is not None :
831+ self ._strip_rollout_payload_for_rerun (replay_meta , RolloutState .EXPIRED )
702832 self ._expired_actions .extend (bucket )
703833 self .logger .info (
704834 f"Moved { len (bucket )} completed samples with version { version } to expired samples due to exceeding tail_batch_candidate_steps."
@@ -709,6 +839,10 @@ def _check_completed_samples_aborted(self):
709839 return
710840
711841 for version , bucket in self ._completed_actions .items ():
842+ for action_id in bucket :
843+ replay_meta = self ._actions .get (action_id )
844+ if replay_meta is not None :
845+ self ._strip_rollout_payload_for_rerun (replay_meta , RolloutState .ABORTED )
712846 self ._aborted_actions [0 ].extend (bucket )
713847 self .logger .info (
714848 f"Moved { len (bucket )} completed samples with version { version } to aborted samples due to partial rollout disabled."
@@ -729,7 +863,9 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta):
729863 if state and observation_id in self ._states .get (state , []):
730864 self ._states [state ].remove (observation_id )
731865
866+ self ._actions .pop (action_id , None )
732867 self ._action2observations .pop (action_id , None )
868+ self ._free_replay_meta_refs (replay_meta )
733869 del replay_meta
734870
735871 def _clear_meta_for_root (self , replay_meta : ReplayMeta ):
@@ -747,13 +883,16 @@ def _clear_meta_for_root(self, replay_meta: ReplayMeta):
747883 and clear all related actions.
748884 """
749885 root_id = replay_meta .root_id
886+ current_action_id = replay_meta .action_id
887+ self ._clear_meta_for_actions (replay_meta )
750888 if root_id in self ._root2actions :
751889 for action_id in self ._root2actions [root_id ]:
890+ if action_id == current_action_id :
891+ continue
752892 new_replay_meta = self ._actions .pop (action_id , None )
753893 if new_replay_meta :
754894 self ._clear_meta_for_actions (new_replay_meta )
755895 del self ._root2actions [root_id ]
756- del replay_meta
757896
758897 def _check_rollout_state_and_insert (self , replay_meta : ReplayMeta ):
759898 """Checks the rollout state of a ReplayMeta object and inserts its
@@ -775,11 +914,14 @@ def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta):
775914 if state == RolloutState .ABORTED :
776915 if self .tail_batch_candidate_steps > 0 and replay_meta .version >= self .tail_batch_candidate_steps :
777916 # 过期的数据需要重置状态
917+ self ._strip_rollout_payload_for_rerun (replay_meta , RolloutState .EXPIRED )
778918 self ._expired_actions .append (action_id )
779919 self .logger .debug (
780920 f"Add expired sample with action_id: { action_id } to _expired_actions because version: { replay_meta .version } >= tail_batch_candidate_steps: { self .tail_batch_candidate_steps } ."
781921 )
782922 else :
923+ if not self .enable_partial_rollout :
924+ self ._strip_rollout_payload_for_rerun (replay_meta , RolloutState .ABORTED )
783925 self ._aborted_actions [replay_meta .version ].append (action_id )
784926 self .logger .debug (
785927 f"Add aborted sample with action_id: { action_id } version: { replay_meta .version } to _aborted_actions."
@@ -903,14 +1045,16 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]):
9031045 self .storage .add (grouped_dataitem )
9041046
9051047 def status (self ):
906- return {
1048+ status = {
9071049 "remain_completed_samples_count" : self .storage .completed_samples_count ,
9081050 "remain_aborted_samples_count" : self .storage .aborted_samples_count ,
9091051 "remain_expired_samples_count" : self .storage .expired_samples_count ,
9101052 "sample_from_dataset_count" : self .sample_from_dataset_count ,
9111053 "sample_from_aborted_count" : self .storage .sample_from_aborted_count ,
9121054 "sample_from_expired_count" : self .storage .sample_from_expired_count ,
9131055 }
1056+ status .update (self .storage .get_storage_stats ())
1057+ return status
9141058
9151059 def save (self , file_path : Path | str ):
9161060 """Saves the replay buffer's storage to a file.
0 commit comments