@@ -183,6 +183,7 @@ def _make_manager(rsync_config=None, load_strategy="rsync"):
183183 manager .model_list = [_FakeModel ()]
184184 manager .state_dict = {}
185185 manager .use_gdr_checkpoint_transfer = True
186+ manager ._gdr_ct_handle = None
186187 return manager
187188
188189
@@ -256,25 +257,33 @@ def test_update_weights_by_gdr_gdr_mode(self):
256257 class FakeCheckpointTransfer :
257258 def __init__ (self , config ):
258259 self .config = config
260+ self .step_ids = []
259261 created .append (self )
260262
261263 def receive_weights_sync (self , step_id , output_framework = "paddle" ):
262- self .step_id = step_id
264+ self .step_ids . append ( step_id )
263265 self .output_framework = output_framework
264- yield "model.layers.0 .weight" , object ()
266+ yield f "model.layers.{ len ( self . step_ids ) } .weight" , object ()
265267
266268 manager = _make_manager ()
267269
268270 with _patch_gdr_checkpoint_transfer (FakeCheckpointTransfer ):
269271 result = manager .update_weights_by_gdr (version = "step-1" )
272+ second_result = manager .update_weights_by_gdr (version = "step-2" )
270273
271274 self .assertEqual (result ["version" ], "step-1" )
275+ self .assertEqual (second_result ["version" ], "step-2" )
272276 self .assertEqual (result ["update_count" ], 1 )
277+ self .assertEqual (second_result ["update_count" ], 1 )
273278 self .assertIn ("total_cost" , result )
274- self .assertEqual (manager .model_list [0 ].loaded [0 ][0 ], "model.layers.0.weight" )
279+ self .assertEqual (
280+ [name for name , _ in manager .model_list [0 ].loaded ], ["model.layers.1.weight" , "model.layers.2.weight" ]
281+ )
282+ self .assertEqual (len (created ), 1 )
283+ self .assertIs (manager ._gdr_ct_handle , created [0 ])
275284 self .assertTrue (created [0 ].initialized )
276- self .assertTrue ( created [0 ]. cleaned )
277- self .assertEqual (created [0 ].step_id , "step-1" )
285+ self .assertFalse ( hasattr ( created [0 ], " cleaned" ) )
286+ self .assertEqual (created [0 ].step_ids , [ "step-1" , "step-2" ] )
278287 self .assertEqual (created [0 ].output_framework , "paddle" )
279288 self .assertEqual (created [0 ].config .kwargs ["role" ], _FakeRole .INFERENCE )
280289 self .assertEqual (created [0 ].config .kwargs ["phase1_backend" ], _FakePhase1Backend .GPU_DIRECT )
@@ -313,9 +322,11 @@ def receive_weights_sync(self, step_id, output_framework="paddle"):
313322 self .assertEqual (created [0 ].config .kwargs ["qsize" ], 2 )
314323
315324 def test_gdr_checkpoint_transfer_receive_exception_propagates (self ):
325+ created = []
326+
316327 class FakeCheckpointTransfer :
317328 def __init__ (self , config ):
318- pass
329+ created . append ( self )
319330
320331 def receive_weights_sync (self , step_id , output_framework = "paddle" ):
321332 yield "model.layers.0.weight" , object ()
@@ -333,6 +344,9 @@ def load_weights(self, weights_iterator):
333344 with self .assertRaisesRegex (RuntimeError , "receive failed" ):
334345 manager .update_weights_by_gdr (version = "step-error" )
335346
347+ self .assertTrue (created [0 ].cleaned )
348+ self .assertIsNone (manager ._gdr_ct_handle )
349+
336350 def test_gdr_checkpoint_transfer_refreshes_state_dict_after_model_loader (self ):
337351 loaded_param = object ()
338352
0 commit comments