Skip to content

Commit ee3f166

Browse files
committed
Update GDR handle reuse tests
1 parent 17d10c9 commit ee3f166

1 file changed

Lines changed: 20 additions & 6 deletions

File tree

tests/rl/test_dynamic_weight_gdr.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def _make_manager(rsync_config=None, load_strategy="rsync"):
183183
manager.model_list = [_FakeModel()]
184184
manager.state_dict = {}
185185
manager.use_gdr_checkpoint_transfer = True
186+
manager._gdr_ct_handle = None
186187
return manager
187188

188189

@@ -256,25 +257,33 @@ def test_update_weights_by_gdr_gdr_mode(self):
256257
class FakeCheckpointTransfer:
257258
def __init__(self, config):
258259
self.config = config
260+
self.step_ids = []
259261
created.append(self)
260262

261263
def receive_weights_sync(self, step_id, output_framework="paddle"):
262-
self.step_id = step_id
264+
self.step_ids.append(step_id)
263265
self.output_framework = output_framework
264-
yield "model.layers.0.weight", object()
266+
yield f"model.layers.{len(self.step_ids)}.weight", object()
265267

266268
manager = _make_manager()
267269

268270
with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer):
269271
result = manager.update_weights_by_gdr(version="step-1")
272+
second_result = manager.update_weights_by_gdr(version="step-2")
270273

271274
self.assertEqual(result["version"], "step-1")
275+
self.assertEqual(second_result["version"], "step-2")
272276
self.assertEqual(result["update_count"], 1)
277+
self.assertEqual(second_result["update_count"], 1)
273278
self.assertIn("total_cost", result)
274-
self.assertEqual(manager.model_list[0].loaded[0][0], "model.layers.0.weight")
279+
self.assertEqual(
280+
[name for name, _ in manager.model_list[0].loaded], ["model.layers.1.weight", "model.layers.2.weight"]
281+
)
282+
self.assertEqual(len(created), 1)
283+
self.assertIs(manager._gdr_ct_handle, created[0])
275284
self.assertTrue(created[0].initialized)
276-
self.assertTrue(created[0].cleaned)
277-
self.assertEqual(created[0].step_id, "step-1")
285+
self.assertFalse(hasattr(created[0], "cleaned"))
286+
self.assertEqual(created[0].step_ids, ["step-1", "step-2"])
278287
self.assertEqual(created[0].output_framework, "paddle")
279288
self.assertEqual(created[0].config.kwargs["role"], _FakeRole.INFERENCE)
280289
self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.GPU_DIRECT)
@@ -313,9 +322,11 @@ def receive_weights_sync(self, step_id, output_framework="paddle"):
313322
self.assertEqual(created[0].config.kwargs["qsize"], 2)
314323

315324
def test_gdr_checkpoint_transfer_receive_exception_propagates(self):
325+
created = []
326+
316327
class FakeCheckpointTransfer:
317328
def __init__(self, config):
318-
pass
329+
created.append(self)
319330

320331
def receive_weights_sync(self, step_id, output_framework="paddle"):
321332
yield "model.layers.0.weight", object()
@@ -333,6 +344,9 @@ def load_weights(self, weights_iterator):
333344
with self.assertRaisesRegex(RuntimeError, "receive failed"):
334345
manager.update_weights_by_gdr(version="step-error")
335346

347+
self.assertTrue(created[0].cleaned)
348+
self.assertIsNone(manager._gdr_ct_handle)
349+
336350
def test_gdr_checkpoint_transfer_refreshes_state_dict_after_model_loader(self):
337351
loaded_param = object()
338352

0 commit comments

Comments
 (0)