@@ -91,6 +91,9 @@ def checker_proc(rank: int, device_uuid: str, named_tensors: dict[str, torch.Ten
9191 name : tensor .to (device_manager .device_type ) for name , tensor in named_tensors .items ()
9292 }
9393 _zmq_ctx = zmq .Context ()
94+ mem_info = device_manager .device_module .mem_get_info ()
95+ memory_usage = mem_info [1 ] - mem_info [0 ]
96+ memory_history : list [int ] = [memory_usage ]
9497
9598 def check (names_to_check : dict [str , bool ], weights : list [tuple [str , torch .Tensor ]]):
9699 for name , weight in weights :
@@ -108,6 +111,11 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
108111 run = lambda weights : check (names_to_check , weights ),
109112 post_hook = lambda : device_manager .device_module .synchronize (),
110113 )
114+ device_manager .device_module .synchronize ()
115+ device_manager .device_module .empty_cache ()
116+ mem_info = device_manager .device_module .mem_get_info ()
117+ memory_usage = mem_info [1 ] - mem_info [0 ]
118+ memory_history .append (memory_usage )
111119 assert all (names_to_check .values ())
112120
113121 while True :
@@ -117,6 +125,12 @@ def check_weights(names_to_check: dict[str, bool], socket_paths: list[tuple[str,
117125 names_to_check = dict .fromkeys (named_tensors .keys (), False )
118126 check_weights (names_to_check , socket_paths )
119127
128+ mem_info = device_manager .device_module .mem_get_info ()
129+ memory_usage = mem_info [1 ] - mem_info [0 ]
130+ memory_history .append (memory_usage )
131+ for memory in memory_history [1 :]:
132+ print (f"[rank{ rank } ] Memory change: { memory - memory_history [0 ]} " )
133+
120134
121135def run (
122136 checker_func : callable ,
@@ -318,6 +332,8 @@ def test_update_with_files(test_name: str = "test_with_files"):
318332 rank_list = json .loads (sys .argv [2 ])
319333 if test_type == "test_no_error" :
320334 run (checker_proc , rank_list , need_error = False )
335+ mem_info = device_manager .device_module .mem_get_info ()
336+ print (f"Memory usage: { mem_info [1 ] - mem_info [0 ]} " )
321337 elif test_type == "test_with_remote_error" :
322338 run (
323339 checker_proc_with_error ,
0 commit comments