11import argparse
2- import dataclasses
32import json
43import os
54import time
6- from collections import defaultdict
7- from typing import Any
85
96import psutil
107import ray
1815 pynvml = None
1916
2017
21- def _maybe_init_nvml ( ):
18+ def monitor_actor_memory ( work_dir : str , interval : int = 60 ):
2219 if pynvml is None :
23- return False
24- try :
25- pynvml .nvmlInit ()
26- return True
27- except Exception :
28- return False
29-
30-
31- def _maybe_shutdown_nvml (initialized : bool ):
32- if initialized :
33- try :
34- pynvml .nvmlShutdown ()
35- except Exception :
36- pass
37-
38-
39- def _state_obj_to_dict (obj : Any ) -> dict [str , Any ]:
40- if isinstance (obj , dict ):
41- return obj
42- if dataclasses .is_dataclass (obj ):
43- return dataclasses .asdict (obj )
44- if hasattr (obj , "model_dump" ):
45- return obj .model_dump ()
46- if hasattr (obj , "dict" ):
47- return obj .dict ()
48- if hasattr (obj , "__dict__" ):
49- return dict (obj .__dict__ )
50- return {}
51-
52-
53- def _sanitize_tag_component (name : str ) -> str :
54- return name .replace ("/" , "_" ).replace (" " , "_" ).replace (":" , "_" ).replace ("." , "_" )
55-
56-
57- def _get_object_store_stats (object_limit : int = 5000 , top_k : int = 10 ):
58- stats : dict [str , Any ] = {
59- "available" : False ,
60- "total_objects" : 0 ,
61- "total_size_mb" : 0.0 ,
62- "callsite_enabled" : 0 ,
63- "summary_by" : "" ,
64- "ref_type_counts" : {},
65- "task_state_counts" : {},
66- "top_callsites" : [],
67- "detail_truncated" : 0 ,
68- "detail_object_count" : 0 ,
69- "detail_total_size_mb" : 0.0 ,
70- "detail_ref_type_counts" : {},
71- "detail_ref_type_size_mb" : {},
72- "top_call_sites_from_objects" : [],
73- "top_pids" : [],
74- "top_ips" : [],
75- }
76-
77- try :
78- from ray .util import state as ray_state
79-
80- summary_raw = ray_state .summarize_objects (timeout = 30 , raise_on_missing_output = False )
81- summary_data = _state_obj_to_dict (summary_raw )
82- stats ["available" ] = True
83- stats ["total_objects" ] = summary_data .get ("total_objects" , 0 )
84- stats ["total_size_mb" ] = float (summary_data .get ("total_size_mb" , 0.0 ) or 0.0 )
85- stats ["callsite_enabled" ] = int (bool (summary_data .get ("callsite_enabled" , False )))
86- stats ["summary_by" ] = summary_data .get ("summary_by" , "" )
87-
88- ref_type_counts = defaultdict (int )
89- task_state_counts = defaultdict (int )
90- callsite_items = []
91- for callsite , item in (summary_data .get ("summary" , {}) or {}).items ():
92- item_dict = _state_obj_to_dict (item )
93- callsite_items .append (
94- {
95- "callsite" : callsite ,
96- "total_size_mb" : float (item_dict .get ("total_size_mb" , 0.0 ) or 0.0 ),
97- "total_objects" : int (item_dict .get ("total_objects" , 0 ) or 0 ),
98- "total_num_workers" : int (item_dict .get ("total_num_workers" , 0 ) or 0 ),
99- "total_num_nodes" : int (item_dict .get ("total_num_nodes" , 0 ) or 0 ),
100- "ref_type_counts" : item_dict .get ("ref_type_counts" , {}) or {},
101- "task_state_counts" : item_dict .get ("task_state_counts" , {}) or {},
102- }
103- )
104- for ref_type , count in (item_dict .get ("ref_type_counts" , {}) or {}).items ():
105- ref_type_counts [str (ref_type )] += int (count )
106- for task_state , count in (item_dict .get ("task_state_counts" , {}) or {}).items ():
107- task_state_counts [str (task_state )] += int (count )
108-
109- callsite_items .sort (key = lambda x : (x ["total_size_mb" ], x ["total_objects" ]), reverse = True )
110- stats ["top_callsites" ] = callsite_items [:top_k ]
111- stats ["ref_type_counts" ] = dict (ref_type_counts )
112- stats ["task_state_counts" ] = dict (task_state_counts )
113-
114- try :
115- object_states = ray_state .list_objects (
116- limit = object_limit , timeout = 30 , detail = False , raise_on_missing_output = False
117- )
118- pid_size_mb = defaultdict (float )
119- pid_count = defaultdict (int )
120- ip_size_mb = defaultdict (float )
121- ip_count = defaultdict (int )
122- ref_type_size_mb = defaultdict (float )
123- ref_type_count = defaultdict (int )
124- callsite_size_mb = defaultdict (float )
125- callsite_count = defaultdict (int )
126-
127- object_state_dicts = [_state_obj_to_dict (obj ) for obj in object_states ]
128- stats ["detail_object_count" ] = len (object_state_dicts )
129- stats ["detail_truncated" ] = int (len (object_state_dicts ) >= object_limit )
130-
131- for obj in object_state_dicts :
132- size_mb = float (obj .get ("object_size" , 0 ) or 0 ) / (1024 ** 2 )
133- pid = str (obj .get ("pid" , "unknown" ))
134- ip = str (obj .get ("ip" , "unknown" ))
135- ref_type = str (obj .get ("reference_type" , "UNKNOWN" ))
136- call_site = str (obj .get ("call_site" , "unknown" ))
137-
138- stats ["detail_total_size_mb" ] += size_mb
139- pid_size_mb [pid ] += size_mb
140- pid_count [pid ] += 1
141- ip_size_mb [ip ] += size_mb
142- ip_count [ip ] += 1
143- ref_type_size_mb [ref_type ] += size_mb
144- ref_type_count [ref_type ] += 1
145- callsite_size_mb [call_site ] += size_mb
146- callsite_count [call_site ] += 1
147-
148- stats ["detail_ref_type_counts" ] = dict (ref_type_count )
149- stats ["detail_ref_type_size_mb" ] = dict (ref_type_size_mb )
150- stats ["top_call_sites_from_objects" ] = [
151- {"callsite" : k , "size_mb" : v , "count" : callsite_count [k ]}
152- for k , v in sorted (callsite_size_mb .items (), key = lambda item : item [1 ], reverse = True )[:top_k ]
153- ]
154- stats ["top_pids" ] = [
155- {"pid" : k , "size_mb" : v , "count" : pid_count [k ]}
156- for k , v in sorted (pid_size_mb .items (), key = lambda item : item [1 ], reverse = True )[:top_k ]
157- ]
158- stats ["top_ips" ] = [
159- {"ip" : k , "size_mb" : v , "count" : ip_count [k ]}
160- for k , v in sorted (ip_size_mb .items (), key = lambda item : item [1 ], reverse = True )[:top_k ]
161- ]
162- except Exception as e :
163- stats ["detail_error" ] = str (e )
164-
165- except Exception as e :
166- stats ["error" ] = str (e )
167-
168- return stats
169-
170-
171- def monitor_actor_memory (work_dir : str , interval : int = 60 , object_limit : int = 5000 , top_k : int = 10 ):
20+ raise ImportError ("pynvml 未安装,无法监控 GPU 内存" )
17221
17322 print (f"开始监控 Actor 内存使用情况,间隔 { interval } 秒..." )
17423 print ("=" * 80 )
17524 os .makedirs (f"{ work_dir } /tb" , exist_ok = True )
176- actor_f = open (f"{ work_dir } /actor_memory.jsonl" , "w" , encoding = "utf-8" )
177- object_f = open (f"{ work_dir } /object_store.jsonl" , "w" , encoding = "utf-8" )
25+ f = open (f"{ work_dir } /actor_memory.json" , "w" )
17826
17927 cluster_resources = ray .cluster_resources ()
18028 total_gpus = int (cluster_resources .get ("GPU" , 0 ))
@@ -187,7 +35,6 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
18735 while True :
18836 count += 1
18937 memory_info = {}
190- object_store_info = {}
19138
19239 # 获取所有 Actor
19340 actors = ray .state .actors ()
@@ -207,18 +54,17 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
20754 try :
20855 process = psutil .Process (pid )
20956 memory_gb = process .memory_info ().rss / 1024 / 1024 / 1024
210- nvml_initialized = _maybe_init_nvml ()
211- if nvml_initialized :
212- device_count = pynvml .nvmlDeviceGetCount ()
213- for i in range (device_count ):
214- handle = pynvml .nvmlDeviceGetHandleByIndex (i )
215- # 检查该GPU是否被当前进程使用
216- compute_procs = pynvml .nvmlDeviceGetComputeRunningProcesses (handle )
217- if any (proc .pid == pid for proc in compute_procs ):
218- mem_info = pynvml .nvmlDeviceGetMemoryInfo (handle )
219- gpu_memory_gb = mem_info .used / 1024 / 1024 / 1024
220- break
221- _maybe_shutdown_nvml (nvml_initialized )
57+ pynvml .nvmlInit ()
58+ device_count = pynvml .nvmlDeviceGetCount ()
59+ for i in range (device_count ):
60+ handle = pynvml .nvmlDeviceGetHandleByIndex (i )
61+ # 检查该GPU是否被当前进程使用
62+ compute_procs = pynvml .nvmlDeviceGetComputeRunningProcesses (handle )
63+ if any (proc .pid == pid for proc in compute_procs ):
64+ mem_info = pynvml .nvmlDeviceGetMemoryInfo (handle )
65+ gpu_memory_gb = mem_info .used / 1024 / 1024 / 1024
66+ break
67+ pynvml .nvmlShutdown ()
22268
22369 except (psutil .NoSuchProcess , psutil .AccessDenied ):
22470 pass
@@ -234,18 +80,10 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
23480 "gpu_mem_gb" : [gpu_memory_gb ],
23581 }
23682
237- object_store_info = _get_object_store_stats (object_limit = object_limit , top_k = top_k )
238- object_store_info ["time" ] = current_time
239- object_store_info ["object_limit" ] = object_limit
240- object_store_info ["top_k" ] = top_k
241-
24283 # 写入文件
243- json .dump (memory_info , actor_f , ensure_ascii = False )
244- actor_f .write ("\n " )
245- actor_f .flush ()
246- json .dump (object_store_info , object_f , ensure_ascii = False )
247- object_f .write ("\n " )
248- object_f .flush ()
84+ json .dump (memory_info , f , ensure_ascii = False )
85+ f .write ("\n " )
86+ f .flush ()
24987
25088 for actor_name , memory_mb_info in memory_info .items ():
25189 if actor_name == "time" :
@@ -290,60 +128,13 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
290128 global_step = count ,
291129 )
292130
293- tb_writer_list [0 ].add_scalar (
294- tag = "ray_object_store/total_size_mb" ,
295- scalar_value = float (object_store_info .get ("total_size_mb" , 0.0 ) or 0.0 ),
296- global_step = count ,
297- )
298- tb_writer_list [0 ].add_scalar (
299- tag = "ray_object_store/total_objects" ,
300- scalar_value = float (object_store_info .get ("total_objects" , 0 ) or 0 ),
301- global_step = count ,
302- )
303- tb_writer_list [0 ].add_scalar (
304- tag = "ray_object_store/detail_total_size_mb" ,
305- scalar_value = float (object_store_info .get ("detail_total_size_mb" , 0.0 ) or 0.0 ),
306- global_step = count ,
307- )
308- tb_writer_list [0 ].add_scalar (
309- tag = "ray_object_store/detail_object_count" ,
310- scalar_value = float (object_store_info .get ("detail_object_count" , 0 ) or 0 ),
311- global_step = count ,
312- )
313- tb_writer_list [0 ].add_scalar (
314- tag = "ray_object_store/detail_truncated" ,
315- scalar_value = float (object_store_info .get ("detail_truncated" , 0 ) or 0 ),
316- global_step = count ,
317- )
318-
319- for ref_type , value in (object_store_info .get ("ref_type_counts" , {}) or {}).items ():
320- tb_writer_list [0 ].add_scalar (
321- tag = f"ray_object_store/ref_type_count/{ _sanitize_tag_component (str (ref_type ))} " ,
322- scalar_value = float (value ),
323- global_step = count ,
324- )
325- for ref_type , value in (object_store_info .get ("detail_ref_type_size_mb" , {}) or {}).items ():
326- tb_writer_list [0 ].add_scalar (
327- tag = f"ray_object_store/ref_type_size_mb/{ _sanitize_tag_component (str (ref_type ))} " ,
328- scalar_value = float (value ),
329- global_step = count ,
330- )
331- for task_state , value in (object_store_info .get ("task_state_counts" , {}) or {}).items ():
332- tb_writer_list [0 ].add_scalar (
333- tag = f"ray_object_store/task_state_count/{ _sanitize_tag_component (str (task_state ))} " ,
334- scalar_value = float (value ),
335- global_step = count ,
336- )
337-
338131 time .sleep (interval )
339132 print (memory_info )
340- print (object_store_info )
341133
342134 except KeyboardInterrupt :
343135 print ("\n 监控已停止" )
344136 finally :
345- actor_f .close ()
346- object_f .close ()
137+ f .close ()
347138 for tb_writer in tb_writer_list :
348139 tb_writer .close ()
349140
@@ -352,24 +143,19 @@ def monitor_actor_memory(work_dir: str, interval: int = 60, object_limit: int =
352143 parser = argparse .ArgumentParser (description = "RL MEMORY MONITOR" )
353144 parser .add_argument ("--work_dir" , type = str , default = "dense_8b" )
354145 parser .add_argument ("--interval" , type = int , default = 60 )
355- parser .add_argument ("--object_limit" , type = int , default = 5000 )
356- parser .add_argument ("--top_k" , type = int , default = 10 )
357146 args = parser .parse_args ()
358147 work_dir = args .work_dir
359148 interval = args .interval
360- object_limit = args .object_limit
361- top_k = args .top_k
362149
363150 while True :
364151 try :
365- if not ray .is_initialized ():
366- ray .init (address = "auto" )
367- time .sleep (interval )
152+ ray .init (address = "auto" )
153+ time .sleep (interval )
368154 break
369155 except KeyboardInterrupt :
370156 print ("\n 监控已停止" )
371157 break
372158 except Exception :
373159 print ("连接 Ray 集群失败, 等等" )
374160
375- monitor_actor_memory (work_dir = work_dir , interval = interval , object_limit = object_limit , top_k = top_k )
161+ monitor_actor_memory (work_dir = work_dir , interval = interval )
0 commit comments