Skip to content

Commit e04db0c

Browse files
committed
[Enhance] Add memory monitoring enhancements for RL training scripts
- Introduced new environment variables for RL memory monitoring: XTUNER_RL_MEM_INTERVAL, XTUNER_RL_OBJECT_LIMIT, and XTUNER_RL_OBJECT_TOP_K.
- Updated run_rl.sh and run_rl_submit.sh to use these new variables when configuring memory monitoring.
- Extended the rl_monitor_actor_memory function to accept additional parameters for the object limit and the top-K objects to monitor.
- Added a new summarize_group_payload function in replay_buffer.py to provide detailed statistics on grouped data items.
- Implemented memory-reference management improvements in controller.py and replay_buffer.py to optimize memory usage during training.

These changes aim to improve the flexibility and efficiency of memory monitoring in RL training workflows.
1 parent 147cb2e commit e04db0c

6 files changed

Lines changed: 432 additions & 40 deletions

File tree

examples/v1/scripts/run_rl.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,9 @@ export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt"
8282
if [ "$ACCELERATOR" = "GPU" ]; then
    # TODO: support NPU RL Memory Monitor
    export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}"
    # Memory-monitor tunables; each keeps a caller-provided value and
    # otherwise falls back to the documented default.
    export XTUNER_RL_MEM_INTERVAL="${XTUNER_RL_MEM_INTERVAL:-60}"
    export XTUNER_RL_OBJECT_LIMIT="${XTUNER_RL_OBJECT_LIMIT:-5000}"
    export XTUNER_RL_OBJECT_TOP_K="${XTUNER_RL_OBJECT_TOP_K:-10}"
fi
8689

8790
# 2. Launch Ray cluster
@@ -139,4 +142,4 @@ LOG_FILE="${WORK_DIR}/training_log_${current_time}.txt"
139142

140143
python xtuner/v1/train/cli/rl.py \
141144
--config $CONFIG_PATH \
142-
2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt"
145+
2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt"

examples/v1/scripts/run_rl_submit.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt"
7373
if [ "$ACCELERATOR" = "GPU" ]; then
    # TODO: support NPU RL Memory Monitor
    export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}"
    # Memory-monitor tunables; each keeps a caller-provided value and
    # otherwise falls back to the documented default.
    export XTUNER_RL_MEM_INTERVAL="${XTUNER_RL_MEM_INTERVAL:-60}"
    export XTUNER_RL_OBJECT_LIMIT="${XTUNER_RL_OBJECT_LIMIT:-5000}"
    export XTUNER_RL_OBJECT_TOP_K="${XTUNER_RL_OBJECT_TOP_K:-10}"
fi
7780

7881
# 2. Launch Ray cluster
@@ -157,4 +160,4 @@ if [ "$RAY_RANK" -eq 0 ]; then
157160
2>&1 | tee -a "$LOG_FILE"
158161

159162
echo "训练任务提交完成。日志文件: $LOG_FILE"
160-
fi
163+
fi

xtuner/v1/ray/dataflow/replay_buffer.py

Lines changed: 147 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,49 @@ class ReplayMeta:
6363
extra_info: Dict[str, Any] = field(default_factory=dict)
6464

6565

66+
def summarize_group_payload(grouped_dataitem: List[RLDataFlowItem]) -> Dict[str, Any]:
    """Build aggregate payload statistics for one group of rollout items.

    Counts response tokens/chars, versioned segments/tokens, routed-expert
    payloads and judged observations across the group, and flags whether the
    shared prompt carries multimodal data. An empty group yields the zeroed
    default summary.
    """
    summary: Dict[str, Any] = {
        "payload_mode": "full",
        "observation_count": len(grouped_dataitem),
        "response_tokens": 0,
        "response_chars": 0,
        "versioned_segments": 0,
        "versioned_tokens": 0,
        "routed_expert_payloads": 0,
        "judged_observations": 0,
        "has_multimodal_prompt": False,
    }
    if not grouped_dataitem:
        return summary

    # Multimodal info lives on the shared prompt, so inspecting the first
    # item is sufficient for the whole group.
    head_data = grouped_dataitem[0].data
    mm_info = getattr(head_data, "multimodal_train_info", None)
    summary["has_multimodal_prompt"] = bool(mm_info) and len(mm_info) > 0

    for entry in grouped_dataitem:
        rollout = entry.env.rollout
        judger = entry.env.judger
        token_ids = rollout.response_ids or []
        text = rollout.response or ""
        segment_ids = rollout.versioned_response_ids or []
        segment_token_counts = rollout.versioned_num_return_tokens or []

        summary["response_tokens"] += len(token_ids)
        summary["response_chars"] += len(text)
        summary["versioned_segments"] += len(segment_ids)
        # Prefer the explicit per-segment token counts; fall back to the
        # lengths of the versioned id lists when counts are absent.
        if segment_token_counts:
            summary["versioned_tokens"] += sum(segment_token_counts)
        else:
            summary["versioned_tokens"] += sum(len(ids) for ids in segment_ids)
        if rollout.extra_info.get("routed_experts", None) is not None:
            summary["routed_expert_payloads"] += 1
        was_judged = (
            judger.uid is not None
            or judger.reward.get("score", 0.0) != 0.0
            or len(judger.extra_info) > 0
        )
        if was_judged:
            summary["judged_observations"] += 1

    return summary
107+
108+
66109
def determine_group_state(group_data_items: List[RLDataFlowItem]) -> RolloutState:
67110
"""Determines the processing strategy for a group of rollout samples based
68111
on their state."""
@@ -113,7 +156,7 @@ def mapping_dataitem_to_replaymeta(grouped_dataitem: List[RLDataFlowItem]) -> Re
113156
observation_refs=observation_refs,
114157
state=group_state,
115158
version=group_version,
116-
extra_info={},
159+
extra_info=summarize_group_payload(grouped_dataitem),
117160
)
118161
return replay_meta
119162

@@ -323,6 +366,87 @@ def __init__(self, replay_buffer_cfg):
323366
self.sample_from_aborted_count = 0
324367
self.sample_from_expired_count = 0
325368

369+
def _free_replay_meta_refs(self, replay_meta: ReplayMeta, include_action_ref: bool = True):
370+
refs = []
371+
if include_action_ref and replay_meta.action_ref is not None:
372+
refs.append(replay_meta.action_ref)
373+
refs.extend([ref for ref in replay_meta.observation_refs if ref is not None])
374+
if refs:
375+
ray.internal.free(refs, local_only=False)
376+
377+
def _update_replay_meta_state(self, replay_meta: ReplayMeta, new_state: RolloutState):
378+
for observation_id in replay_meta.observation_ids:
379+
old_state = self._observations2states.get(observation_id)
380+
if old_state and observation_id in self._states.get(old_state, []):
381+
self._states[old_state].remove(observation_id)
382+
self._observations2states[observation_id] = new_state
383+
if observation_id not in self._states[new_state]:
384+
self._states[new_state].append(observation_id)
385+
replay_meta.state = new_state
386+
387+
def _strip_rollout_payload_for_rerun(self, replay_meta: ReplayMeta, new_state: RolloutState):
    """Keep prompt refs only and drop rollout outputs that will not be reused.

    Frees the old observation refs, replaces each slot with an empty
    placeholder item so the prompt position survives, zeroes the payload
    summary, and transitions the meta into ``new_state``.
    """
    stale_refs = [ref for ref in replay_meta.observation_refs if ref is not None]
    if stale_refs:
        ray.internal.free(stale_refs, local_only=False)
    # One empty placeholder per observation slot keeps the group shape intact.
    replay_meta.observation_refs = [ray.put(RLEnvDataItem()) for _ in replay_meta.observation_ids]
    reset_summary = {
        "payload_mode": "prompt_only",
        "response_tokens": 0,
        "response_chars": 0,
        "versioned_segments": 0,
        "versioned_tokens": 0,
        "routed_expert_payloads": 0,
        "judged_observations": 0,
    }
    replay_meta.extra_info.update(reset_summary)
    self._update_replay_meta_state(replay_meta, new_state)
405+
406+
def get_storage_stats(self) -> Dict[str, float]:
    """Snapshot counters describing what the replay storage currently holds.

    Aggregates the per-action payload summaries stored in each meta's
    ``extra_info`` into buffer-wide totals. Every value is a float so the
    dict can be logged/merged uniformly with other metric payloads.
    """
    stats: Dict[str, float] = {
        "tracked_actions_count": float(len(self._actions)),
        "tracked_roots_count": float(len(self._root2actions)),
        "tracked_observations_count": float(len(self._observations)),
        "completed_actions_count": float(sum(len(bucket) for bucket in self._completed_actions.values())),
        "aborted_actions_count": float(sum(len(bucket) for bucket in self._aborted_actions.values())),
        "expired_actions_count": float(len(self._expired_actions)),
        "completed_versions_count": float(len(self._completed_actions)),
        "aborted_versions_count": float(len(self._aborted_actions)),
        "payload_full_actions_count": 0.0,
        "payload_prompt_only_actions_count": 0.0,
        "payload_full_observations_count": 0.0,
        "payload_prompt_only_observations_count": 0.0,
        "stored_response_tokens": 0.0,
        "stored_response_chars": 0.0,
        "stored_versioned_segments": 0.0,
        "stored_versioned_tokens": 0.0,
        "stored_routed_expert_payloads": 0.0,
        "stored_judged_observations": 0.0,
        "multimodal_actions_count": 0.0,
    }

    # (stats key, summary key) pairs accumulated verbatim from each summary.
    accumulated = (
        ("stored_response_tokens", "response_tokens"),
        ("stored_response_chars", "response_chars"),
        ("stored_versioned_segments", "versioned_segments"),
        ("stored_versioned_tokens", "versioned_tokens"),
        ("stored_routed_expert_payloads", "routed_expert_payloads"),
        ("stored_judged_observations", "judged_observations"),
    )
    for meta in self._actions.values():
        summary = meta.extra_info
        obs_count = float(summary.get("observation_count", len(meta.observation_ids)))
        if summary.get("payload_mode", "full") == "prompt_only":
            stats["payload_prompt_only_actions_count"] += 1.0
            stats["payload_prompt_only_observations_count"] += obs_count
        else:
            stats["payload_full_actions_count"] += 1.0
            stats["payload_full_observations_count"] += obs_count

        for stats_key, summary_key in accumulated:
            stats[stats_key] += float(summary.get(summary_key, 0))
        if summary.get("has_multimodal_prompt", False):
            stats["multimodal_actions_count"] += 1.0

    return stats
449+
326450
def add(self, grouped_dataitem: List[RLDataFlowItem]):
327451
"""Adds a group of data items to the storage.
328452
@@ -426,6 +550,8 @@ def sample(self, sample_from_expired_states) -> List[RLDataFlowItem]:
426550
return []
427551

428552
def clear(self):
553+
for replay_meta in self._actions.values():
554+
self._free_replay_meta_refs(replay_meta)
429555
attrs_to_clear = [
430556
"_aborted_actions",
431557
"_completed_actions",
@@ -699,6 +825,10 @@ def _check_completed_samples_expired(self):
699825

700826
for version in expired_versions:
701827
bucket = self._completed_actions.pop(version)
828+
for action_id in bucket:
829+
replay_meta = self._actions.get(action_id)
830+
if replay_meta is not None:
831+
self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED)
702832
self._expired_actions.extend(bucket)
703833
self.logger.info(
704834
f"Moved {len(bucket)} completed samples with version {version} to expired samples due to exceeding tail_batch_candidate_steps."
@@ -709,6 +839,10 @@ def _check_completed_samples_aborted(self):
709839
return
710840

711841
for version, bucket in self._completed_actions.items():
842+
for action_id in bucket:
843+
replay_meta = self._actions.get(action_id)
844+
if replay_meta is not None:
845+
self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED)
712846
self._aborted_actions[0].extend(bucket)
713847
self.logger.info(
714848
f"Moved {len(bucket)} completed samples with version {version} to aborted samples due to partial rollout disabled."
@@ -729,7 +863,9 @@ def _clear_meta_for_actions(self, replay_meta: ReplayMeta):
729863
if state and observation_id in self._states.get(state, []):
730864
self._states[state].remove(observation_id)
731865

866+
self._actions.pop(action_id, None)
732867
self._action2observations.pop(action_id, None)
868+
self._free_replay_meta_refs(replay_meta)
733869
del replay_meta
734870

735871
def _clear_meta_for_root(self, replay_meta: ReplayMeta):
@@ -747,13 +883,16 @@ def _clear_meta_for_root(self, replay_meta: ReplayMeta):
747883
and clear all related actions.
748884
"""
749885
root_id = replay_meta.root_id
886+
current_action_id = replay_meta.action_id
887+
self._clear_meta_for_actions(replay_meta)
750888
if root_id in self._root2actions:
751889
for action_id in self._root2actions[root_id]:
890+
if action_id == current_action_id:
891+
continue
752892
new_replay_meta = self._actions.pop(action_id, None)
753893
if new_replay_meta:
754894
self._clear_meta_for_actions(new_replay_meta)
755895
del self._root2actions[root_id]
756-
del replay_meta
757896

758897
def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta):
759898
"""Checks the rollout state of a ReplayMeta object and inserts its
@@ -775,11 +914,14 @@ def _check_rollout_state_and_insert(self, replay_meta: ReplayMeta):
775914
if state == RolloutState.ABORTED:
776915
if self.tail_batch_candidate_steps > 0 and replay_meta.version >= self.tail_batch_candidate_steps:
777916
# 过期的数据需要重置状态
917+
self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.EXPIRED)
778918
self._expired_actions.append(action_id)
779919
self.logger.debug(
780920
f"Add expired sample with action_id: {action_id} to _expired_actions because version: {replay_meta.version} >= tail_batch_candidate_steps: {self.tail_batch_candidate_steps}."
781921
)
782922
else:
923+
if not self.enable_partial_rollout:
924+
self._strip_rollout_payload_for_rerun(replay_meta, RolloutState.ABORTED)
783925
self._aborted_actions[replay_meta.version].append(action_id)
784926
self.logger.debug(
785927
f"Add aborted sample with action_id: {action_id} version: {replay_meta.version} to _aborted_actions."
@@ -903,14 +1045,16 @@ def add(self, grouped_dataitem: List[RLDataFlowItem]):
9031045
self.storage.add(grouped_dataitem)
9041046

9051047
def status(self):
    """Report sampling counters merged with detailed storage statistics."""
    report = {
        "remain_completed_samples_count": self.storage.completed_samples_count,
        "remain_aborted_samples_count": self.storage.aborted_samples_count,
        "remain_expired_samples_count": self.storage.expired_samples_count,
        "sample_from_dataset_count": self.sample_from_dataset_count,
        "sample_from_aborted_count": self.storage.sample_from_aborted_count,
        "sample_from_expired_count": self.storage.sample_from_expired_count,
    }
    # Fold in the per-payload storage breakdown alongside the sample counters.
    report.update(self.storage.get_storage_stats())
    return report
9141058

9151059
def save(self, file_path: Path | str):
9161060
"""Saves the replay buffer's storage to a file.

xtuner/v1/rl/base/controller.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import ray
66
import torch
7+
from ray import ObjectRef
78
from ray.actor import ActorProxy
89

910
from xtuner.v1.data_proto.sequence_context import SequenceContext
@@ -28,6 +29,23 @@ class RawTrainingController:
2829
def __init__(self, workers: list[TrainingWorker]) -> None:
2930
self.workers = workers
3031

32+
def _collect_object_refs(self, obj, refs: list[ObjectRef]):
    """Recursively gather every Ray ``ObjectRef`` found in ``obj`` into ``refs``.

    Descends into lists and tuples; values of any other type are ignored.
    """
    if isinstance(obj, ObjectRef):
        refs.append(obj)
    elif isinstance(obj, (list, tuple)):
        for element in obj:
            self._collect_object_refs(element, refs)
39+
40+
def _free_batch_object_refs(self, data_batches):
41+
refs: list[ObjectRef] = []
42+
for data in data_batches:
43+
seq_ctx = data["seq_ctx"]
44+
self._collect_object_refs(seq_ctx.pixel_values, refs)
45+
self._collect_object_refs(seq_ctx.rollout_routed_experts, refs)
46+
if refs:
47+
ray.internal.free(refs, local_only=False)
48+
3149
# TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack
3250
def _get_pack_infos(self, dataset, num_tokens, target, random=None):
3351
inds = list(range(len(dataset)))
@@ -260,7 +278,10 @@ def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx:
260278
rollout_idx=rollout_idx,
261279
)
262280
)
263-
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
281+
try:
282+
log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
283+
finally:
284+
self._free_batch_object_refs(packed_data_batches)
264285
return log_infos
265286

266287
@ray_method

xtuner/v1/train/cli/rl.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,20 @@
2020
)
2121

2222

23-
def rl_monitor_actor_memory(work_dir, interval: int = 60, object_limit: int = 5000, top_k: int = 10):
    """Connect to the running Ray cluster and start the actor-memory monitor.

    Args:
        work_dir: Directory where memory reports are written.
        interval: Seconds between monitoring samples, and the backoff
            between failed connection attempts.
        object_limit: Maximum number of tracked objects forwarded to
            ``monitor_actor_memory``.
        top_k: Number of largest objects to report per sample.

    Fix vs. previous version: the retry loop slept ``interval`` seconds
    AFTER a successful ``ray.init`` (delaying the first sample for no
    reason) while retrying FAILED connections in a tight loop with no
    delay. The sleep now lives in the failure branch as a retry backoff.
    """
    if not ray.is_initialized():
        while True:
            try:
                ray.init(address="auto")
                break
            except KeyboardInterrupt:
                print("\n监控已停止")
                return
            except Exception:
                # Cluster not reachable yet: back off before retrying
                # instead of busy-spinning.
                print("连接 Ray 集群失败, 等等")
                time.sleep(interval)

    monitor_actor_memory(work_dir=work_dir, interval=interval, object_limit=object_limit, top_k=top_k)
3637

3738

3839
@app.default()
@@ -51,7 +52,13 @@ def main(
5152

5253
if os.getenv("XTUNER_RL_MEM_DIR"):
5354
print("Start to monitor actor memory")
54-
track_thread = threading.Thread(target=rl_monitor_actor_memory, args=(os.getenv("XTUNER_RL_MEM_DIR"),))
55+
monitor_interval = int(os.getenv("XTUNER_RL_MEM_INTERVAL", "60"))
56+
object_limit = int(os.getenv("XTUNER_RL_OBJECT_LIMIT", "5000"))
57+
top_k = int(os.getenv("XTUNER_RL_OBJECT_TOP_K", "10"))
58+
track_thread = threading.Thread(
59+
target=rl_monitor_actor_memory,
60+
args=(os.getenv("XTUNER_RL_MEM_DIR"), monitor_interval, object_limit, top_k),
61+
)
5562
track_thread.daemon = True
5663
track_thread.start()
5764

0 commit comments

Comments
 (0)