
Commit 0a7cf3a

Restore RL memory scripts to github/main
1 parent e04db0c commit 0a7cf3a

3 files changed: 23 additions & 243 deletions


examples/v1/scripts/run_rl.sh

Lines changed: 1 addition & 4 deletions
@@ -82,9 +82,6 @@ export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt"
 if [ "$ACCELERATOR" = "GPU" ]; then
     # TODO: support NPU RL Memory Monitor
     export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}"
-    export XTUNER_RL_MEM_INTERVAL="${XTUNER_RL_MEM_INTERVAL:-60}"
-    export XTUNER_RL_OBJECT_LIMIT="${XTUNER_RL_OBJECT_LIMIT:-5000}"
-    export XTUNER_RL_OBJECT_TOP_K="${XTUNER_RL_OBJECT_TOP_K:-10}"
 fi

 # 2. Launch Ray cluster
@@ -142,4 +139,4 @@ LOG_FILE="${WORK_DIR}/training_log_${current_time}.txt"

 python xtuner/v1/train/cli/rl.py \
     --config $CONFIG_PATH \
-    2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt"
+    2>&1 | tee -a "${WORK_DIR}/training_log_${current_time}.txt"
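
The three deleted exports only set override-able defaults (the `${VAR:-default}` form keeps any value already present in the environment) for the monitor knobs that this commit also removes from track_rl_mem.py below. As a hedged illustration of that wiring — the diff itself does not show where these variables are read — a monitor process could pick them up as sketched here; the variable names come from the script, while the fallback directory and the consuming code are assumptions:

import os

# Hypothetical consumer of the (now removed) environment knobs, shown only to
# clarify what the deleted defaults meant. Not part of the repository.
mem_dir = os.environ.get("XTUNER_RL_MEM_DIR", "./mem")                 # output directory
interval = int(os.environ.get("XTUNER_RL_MEM_INTERVAL", "60"))         # sampling period, seconds
object_limit = int(os.environ.get("XTUNER_RL_OBJECT_LIMIT", "5000"))   # max objects to list
top_k = int(os.environ.get("XTUNER_RL_OBJECT_TOP_K", "10"))            # top-k entries to report

print(mem_dir, interval, object_limit, top_k)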

examples/v1/scripts/run_rl_submit.sh

Lines changed: 1 addition & 4 deletions
@@ -73,9 +73,6 @@ export LMDEPLOY_LOG_FILE="${WORK_DIR}/lmdeploy_log_${current_time}.txt"
 if [ "$ACCELERATOR" = "GPU" ]; then
     # TODO: support NPU RL Memory Monitor
     export XTUNER_RL_MEM_DIR="${WORK_DIR}/mem_${current_time}"
-    export XTUNER_RL_MEM_INTERVAL="${XTUNER_RL_MEM_INTERVAL:-60}"
-    export XTUNER_RL_OBJECT_LIMIT="${XTUNER_RL_OBJECT_LIMIT:-5000}"
-    export XTUNER_RL_OBJECT_TOP_K="${XTUNER_RL_OBJECT_TOP_K:-10}"
 fi

 # 2. Launch Ray cluster
@@ -160,4 +157,4 @@ if [ "$RAY_RANK" -eq 0 ]; then
         2>&1 | tee -a "$LOG_FILE"

     echo "训练任务提交完成。日志文件: $LOG_FILE"
-fi
+fi

xtuner/v1/utils/track_rl_mem.py

Lines changed: 21 additions & 235 deletions
@@ -1,10 +1,7 @@
 import argparse
-import dataclasses
 import json
 import os
 import time
-from collections import defaultdict
-from typing import Any

 import psutil
 import ray
@@ -18,163 +15,14 @@
     pynvml = None


-def _maybe_init_nvml():
+def monitor_actor_memory(work_dir: str, interval: int = 60):
     if pynvml is None:
-        return False
-    try:
-        pynvml.nvmlInit()
-        return True
-    except Exception:
-        return False
-
-
-def _maybe_shutdown_nvml(initialized: bool):
-    if initialized:
-        try:
-            pynvml.nvmlShutdown()
-        except Exception:
-            pass
-
-
-def _state_obj_to_dict(obj: Any) -> dict[str, Any]:
-    if isinstance(obj, dict):
-        return obj
-    if dataclasses.is_dataclass(obj):
-        return dataclasses.asdict(obj)
-    if hasattr(obj, "model_dump"):
-        return obj.model_dump()
-    if hasattr(obj, "dict"):
-        return obj.dict()
-    if hasattr(obj, "__dict__"):
-        return dict(obj.__dict__)
-    return {}
-
-
-def _sanitize_tag_component(name: str) -> str:
-    return name.replace("/", "_").replace(" ", "_").replace(":", "_").replace(".", "_")
-
-
-def _get_object_store_stats(object_limit: int = 5000, top_k: int = 10):
-    stats: dict[str, Any] = {
-        "available": False,
-        "total_objects": 0,
-        "total_size_mb": 0.0,
-        "callsite_enabled": 0,
-        "summary_by": "",
-        "ref_type_counts": {},
-        "task_state_counts": {},
-        "top_callsites": [],
-        "detail_truncated": 0,
-        "detail_object_count": 0,
-        "detail_total_size_mb": 0.0,
-        "detail_ref_type_counts": {},
-        "detail_ref_type_size_mb": {},
-        "top_call_sites_from_objects": [],
-        "top_pids": [],
-        "top_ips": [],
-    }
-
-    try:
-        from ray.util import state as ray_state
-
-        summary_raw = ray_state.summarize_objects(timeout=30, raise_on_missing_output=False)
-        summary_data = _state_obj_to_dict(summary_raw)
-        stats["available"] = True
-        stats["total_objects"] = summary_data.get("total_objects", 0)
-        stats["total_size_mb"] = float(summary_data.get("total_size_mb", 0.0) or 0.0)
-        stats["callsite_enabled"] = int(bool(summary_data.get("callsite_enabled", False)))
-        stats["summary_by"] = summary_data.get("summary_by", "")
-
-        ref_type_counts = defaultdict(int)
-        task_state_counts = defaultdict(int)
-        callsite_items = []
-        for callsite, item in (summary_data.get("summary", {}) or {}).items():
-            item_dict = _state_obj_to_dict(item)
-            callsite_items.append(
-                {
-                    "callsite": callsite,
-                    "total_size_mb": float(item_dict.get("total_size_mb", 0.0) or 0.0),
-                    "total_objects": int(item_dict.get("total_objects", 0) or 0),
-                    "total_num_workers": int(item_dict.get("total_num_workers", 0) or 0),
-                    "total_num_nodes": int(item_dict.get("total_num_nodes", 0) or 0),
-                    "ref_type_counts": item_dict.get("ref_type_counts", {}) or {},
-                    "task_state_counts": item_dict.get("task_state_counts", {}) or {},
-                }
-            )
-            for ref_type, count in (item_dict.get("ref_type_counts", {}) or {}).items():
-                ref_type_counts[str(ref_type)] += int(count)
-            for task_state, count in (item_dict.get("task_state_counts", {}) or {}).items():
-                task_state_counts[str(task_state)] += int(count)
-
-        callsite_items.sort(key=lambda x: (x["total_size_mb"], x["total_objects"]), reverse=True)
-        stats["top_callsites"] = callsite_items[:top_k]
-        stats["ref_type_counts"] = dict(ref_type_counts)
-        stats["task_state_counts"] = dict(task_state_counts)
-
-        try:
-            object_states = ray_state.list_objects(
-                limit=object_limit, timeout=30, detail=False, raise_on_missing_output=False
-            )
-            pid_size_mb = defaultdict(float)
-            pid_count = defaultdict(int)
-            ip_size_mb = defaultdict(float)
-            ip_count = defaultdict(int)
-            ref_type_size_mb = defaultdict(float)
-            ref_type_count = defaultdict(int)
-            callsite_size_mb = defaultdict(float)
-            callsite_count = defaultdict(int)
-
-            object_state_dicts = [_state_obj_to_dict(obj) for obj in object_states]
-            stats["detail_object_count"] = len(object_state_dicts)
-            stats["detail_truncated"] = int(len(object_state_dicts) >= object_limit)
-
-            for obj in object_state_dicts:
-                size_mb = float(obj.get("object_size", 0) or 0) / (1024**2)
-                pid = str(obj.get("pid", "unknown"))
-                ip = str(obj.get("ip", "unknown"))
-                ref_type = str(obj.get("reference_type", "UNKNOWN"))
-                call_site = str(obj.get("call_site", "unknown"))
-
-                stats["detail_total_size_mb"] += size_mb
-                pid_size_mb[pid] += size_mb
-                pid_count[pid] += 1
-                ip_size_mb[ip] += size_mb
-                ip_count[ip] += 1
-                ref_type_size_mb[ref_type] += size_mb
-                ref_type_count[ref_type] += 1
-                callsite_size_mb[call_site] += size_mb
-                callsite_count[call_site] += 1
-
-            stats["detail_ref_type_counts"] = dict(ref_type_count)
-            stats["detail_ref_type_size_mb"] = dict(ref_type_size_mb)
-            stats["top_call_sites_from_objects"] = [
-                {"callsite": k, "size_mb": v, "count": callsite_count[k]}
-                for k, v in sorted(callsite_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k]
-            ]
-            stats["top_pids"] = [
-                {"pid": k, "size_mb": v, "count": pid_count[k]}
-                for k, v in sorted(pid_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k]
-            ]
-            stats["top_ips"] = [
-                {"ip": k, "size_mb": v, "count": ip_count[k]}
-                for k, v in sorted(ip_size_mb.items(), key=lambda item: item[1], reverse=True)[:top_k]
-            ]
-        except Exception as e:
-            stats["detail_error"] = str(e)
-
-    except Exception as e:
-        stats["error"] = str(e)
-
-    return stats
-
-
-def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int = 5000, top_k: int = 10):
+        raise ImportError("pynvml 未安装,无法监控 GPU 内存")

     print(f"开始监控 Actor 内存使用情况,间隔 {interval} 秒...")
     print("=" * 80)
     os.makedirs(f"{work_dir}/tb", exist_ok=True)
-    actor_f = open(f"{work_dir}/actor_memory.jsonl", "w", encoding="utf-8")
-    object_f = open(f"{work_dir}/object_store.jsonl", "w", encoding="utf-8")
+    f = open(f"{work_dir}/actor_memory.json", "w")

     cluster_resources = ray.cluster_resources()
     total_gpus = int(cluster_resources.get("GPU", 0))
@@ -187,7 +35,6 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
         while True:
             count += 1
             memory_info = {}
-            object_store_info = {}

             # 获取所有 Actor
             actors = ray.state.actors()
@@ -207,18 +54,17 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
                 try:
                     process = psutil.Process(pid)
                     memory_gb = process.memory_info().rss / 1024 / 1024 / 1024
-                    nvml_initialized = _maybe_init_nvml()
-                    if nvml_initialized:
-                        device_count = pynvml.nvmlDeviceGetCount()
-                        for i in range(device_count):
-                            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
-                            # 检查该GPU是否被当前进程使用
-                            compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
-                            if any(proc.pid == pid for proc in compute_procs):
-                                mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
-                                gpu_memory_gb = mem_info.used / 1024 / 1024 / 1024
-                                break
-                    _maybe_shutdown_nvml(nvml_initialized)
+                    pynvml.nvmlInit()
+                    device_count = pynvml.nvmlDeviceGetCount()
+                    for i in range(device_count):
+                        handle = pynvml.nvmlDeviceGetHandleByIndex(i)
+                        # 检查该GPU是否被当前进程使用
+                        compute_procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
+                        if any(proc.pid == pid for proc in compute_procs):
+                            mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+                            gpu_memory_gb = mem_info.used / 1024 / 1024 / 1024
+                            break
+                    pynvml.nvmlShutdown()

                 except (psutil.NoSuchProcess, psutil.AccessDenied):
                     pass
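
The restored inline NVML block above initializes NVML, scans every visible device for the actor's PID among its compute processes, and records that device's used memory. Below is a self-contained sketch of the same lookup, using only the pynvml calls visible in the diff; the helper name and the `__main__` demo are ours, not part of the repository, and — like the restored loop — it reports the device's total used memory, not the memory attributable to the single process.

import os

import pynvml  # NVIDIA Management Library bindings, as imported in track_rl_mem.py


def gpu_memory_used_by_pid_gb(pid: int) -> float:
    """Return used memory (GiB) of the first GPU whose compute processes include `pid`."""
    pynvml.nvmlInit()
    try:
        for i in range(pynvml.nvmlDeviceGetCount()):
            handle = pynvml.nvmlDeviceGetHandleByIndex(i)
            procs = pynvml.nvmlDeviceGetComputeRunningProcesses(handle)
            if any(p.pid == pid for p in procs):
                mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
                return mem.used / 1024 / 1024 / 1024  # device-wide used memory
        return 0.0
    finally:
        pynvml.nvmlShutdown()


if __name__ == "__main__":
    print(gpu_memory_used_by_pid_gb(os.getpid()))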
@@ -234,18 +80,10 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
                     "gpu_mem_gb": [gpu_memory_gb],
                 }

-            object_store_info = _get_object_store_stats(object_limit=object_limit, top_k=top_k)
-            object_store_info["time"] = current_time
-            object_store_info["object_limit"] = object_limit
-            object_store_info["top_k"] = top_k
-
             # 写入文件
-            json.dump(memory_info, actor_f, ensure_ascii=False)
-            actor_f.write("\n")
-            actor_f.flush()
-            json.dump(object_store_info, object_f, ensure_ascii=False)
-            object_f.write("\n")
-            object_f.flush()
+            json.dump(memory_info, f, ensure_ascii=False)
+            f.write("\n")
+            f.flush()

             for actor_name, memory_mb_info in memory_info.items():
                 if actor_name == "time":
@@ -290,60 +128,13 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
                     global_step=count,
                 )

-            tb_writer_list[0].add_scalar(
-                tag="ray_object_store/total_size_mb",
-                scalar_value=float(object_store_info.get("total_size_mb", 0.0) or 0.0),
-                global_step=count,
-            )
-            tb_writer_list[0].add_scalar(
-                tag="ray_object_store/total_objects",
-                scalar_value=float(object_store_info.get("total_objects", 0) or 0),
-                global_step=count,
-            )
-            tb_writer_list[0].add_scalar(
-                tag="ray_object_store/detail_total_size_mb",
-                scalar_value=float(object_store_info.get("detail_total_size_mb", 0.0) or 0.0),
-                global_step=count,
-            )
-            tb_writer_list[0].add_scalar(
-                tag="ray_object_store/detail_object_count",
-                scalar_value=float(object_store_info.get("detail_object_count", 0) or 0),
-                global_step=count,
-            )
-            tb_writer_list[0].add_scalar(
-                tag="ray_object_store/detail_truncated",
-                scalar_value=float(object_store_info.get("detail_truncated", 0) or 0),
-                global_step=count,
-            )
-
-            for ref_type, value in (object_store_info.get("ref_type_counts", {}) or {}).items():
-                tb_writer_list[0].add_scalar(
-                    tag=f"ray_object_store/ref_type_count/{_sanitize_tag_component(str(ref_type))}",
-                    scalar_value=float(value),
-                    global_step=count,
-                )
-            for ref_type, value in (object_store_info.get("detail_ref_type_size_mb", {}) or {}).items():
-                tb_writer_list[0].add_scalar(
-                    tag=f"ray_object_store/ref_type_size_mb/{_sanitize_tag_component(str(ref_type))}",
-                    scalar_value=float(value),
-                    global_step=count,
-                )
-            for task_state, value in (object_store_info.get("task_state_counts", {}) or {}).items():
-                tb_writer_list[0].add_scalar(
-                    tag=f"ray_object_store/task_state_count/{_sanitize_tag_component(str(task_state))}",
-                    scalar_value=float(value),
-                    global_step=count,
-                )
-
             time.sleep(interval)
             print(memory_info)
-            print(object_store_info)

     except KeyboardInterrupt:
         print("\n监控已停止")
     finally:
-        actor_f.close()
-        object_f.close()
+        f.close()
         for tb_writer in tb_writer_list:
             tb_writer.close()
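
The per-actor scalars that survive this commit are still written through `tb_writer_list[0].add_scalar(tag=..., scalar_value=..., global_step=count)`. The diff does not show how `tb_writer_list` is constructed; the sketch below assumes the usual `torch.utils.tensorboard.SummaryWriter` pointed at the `{work_dir}/tb` directory created earlier, and the tag name and sample values are purely illustrative.

from torch.utils.tensorboard import SummaryWriter  # assumed writer type; not shown in the diff

# Minimal sketch of the add_scalar pattern retained by this commit:
# one scalar per (tag, global_step), flushed into work_dir/tb for TensorBoard.
writer = SummaryWriter(log_dir="work_dir/tb")     # mirrors os.makedirs(f"{work_dir}/tb") above
for step, rss_gb in enumerate([1.2, 1.3, 1.5]):   # fake samples, for illustration only
    writer.add_scalar(tag="actor_memory/ExampleActor/mem_gb", scalar_value=rss_gb, global_step=step)
writer.close()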

@@ -352,24 +143,19 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
     parser = argparse.ArgumentParser(description="RL MEMORY MONITOR")
     parser.add_argument("--work_dir", type=str, default="dense_8b")
     parser.add_argument("--interval", type=int, default=60)
-    parser.add_argument("--object_limit", type=int, default=5000)
-    parser.add_argument("--top_k", type=int, default=10)
     args = parser.parse_args()
     work_dir = args.work_dir
     interval = args.interval
-    object_limit = args.object_limit
-    top_k = args.top_k

     while True:
         try:
-            if not ray.is_initialized():
-                ray.init(address="auto")
-            time.sleep(interval)
+            ray.init(address="auto")
+            time.sleep(interval)
             break
         except KeyboardInterrupt:
             print("\n监控已停止")
             break
         except Exception:
             print("连接 Ray 集群失败, 等等")

-    monitor_actor_memory(work_dir=work_dir, interval=interval, object_limit=object_limit, top_k=top_k)
+    monitor_actor_memory(work_dir=work_dir, interval=interval)
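
With the object-store tracking gone, the restored monitor's only file output is `{work_dir}/actor_memory.json`: one JSON object per loop iteration, written one per line despite the `.json` suffix, each holding a `time` field plus a per-actor entry of the form `{"mem_gb": [...], "gpu_mem_gb": [...]}`. A rough consumer under that assumption is sketched below; the peak-summary helper is ours, not part of the repository, and the path in the demo is illustrative.

import json


def summarize_peaks(path: str) -> dict[str, float]:
    """Return the peak RSS (GiB) seen for each actor across all samples."""
    peaks: dict[str, float] = {}
    with open(path, encoding="utf-8") as fp:
        for line in fp:
            record = json.loads(line)
            for actor, mem in record.items():
                if actor == "time":  # timestamp field, not an actor
                    continue
                rss = max(mem.get("mem_gb", [0.0]))
                peaks[actor] = max(peaks.get(actor, 0.0), rss)
    return peaks


if __name__ == "__main__":
    print(summarize_peaks("work_dir/actor_memory.json"))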

0 commit comments
