Skip to content

Commit b69ad2a

Browse files
committed
Reuse GDR checkpoint transfer handle
1 parent e8ae0f9 commit b69ad2a

2 files changed

Lines changed: 51 additions & 15 deletions

File tree

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
5555
self._capture_model_state()
5656
self.rdma_handle = None
5757
self.use_gdr_checkpoint_transfer = envs.FD_USE_GDR_CHECKPOINT_TRANSFER
58+
self._gdr_ct_handle = None
5859

5960
if self.use_gdr_checkpoint_transfer:
6061
self.update_weights_by_gdr()
@@ -175,14 +176,8 @@ def update_weights_by_gdr(
175176
f"load_strategy:{self.load_config.load_strategy}, step_id:{step_id}"
176177
)
177178

178-
from checkpoint_transfer.transfer import CheckpointTransfer
179-
180-
transfer_config = self._build_ct_transfer_config(config)
181-
logger.info(f"CheckpointTransfer config:{transfer_config}")
182-
ct_handle = CheckpointTransfer(transfer_config)
183-
184179
total_start = time.perf_counter()
185-
asyncio.run(ct_handle.initialize())
180+
ct_handle = self._ensure_gdr_handle(config)
186181
try:
187182
weights_iterator = ct_handle.receive_weights_sync(step_id=step_id, output_framework="paddle")
188183

@@ -192,8 +187,9 @@ def update_weights_by_gdr(
192187
paddle.empty(target_param.shape, dtype=target_param.dtype)._share_buffer_to(target_param)
193188
logger.debug(f"Restored cleared parameter storage before GDR checkpoint transfer load: {name}")
194189
update_count, mtp_cache_count = self._load_models_from_weight_iterator(weights_iterator)
195-
finally:
196-
asyncio.run(ct_handle.cleanup())
190+
except Exception:
191+
self._destroy_gdr_handle()
192+
raise
197193
self._capture_model_state(log_params=False)
198194
total_cost = time.perf_counter() - total_start
199195
logger.info(
@@ -210,6 +206,32 @@ def update_weights_by_gdr(
210206
"mtp_cache_count": mtp_cache_count,
211207
}
212208

209+
def _ensure_gdr_handle(self, config: dict):
210+
"""Lazily create and initialize the CheckpointTransfer handle (once)."""
211+
if self._gdr_ct_handle is not None:
212+
return self._gdr_ct_handle
213+
214+
transfer_config = self._build_ct_transfer_config(config)
215+
logger.info(f"CheckpointTransfer config:{transfer_config}")
216+
217+
from checkpoint_transfer.transfer import CheckpointTransfer
218+
219+
ct_handle = CheckpointTransfer(transfer_config)
220+
asyncio.run(ct_handle.initialize())
221+
222+
self._gdr_ct_handle = ct_handle
223+
logger.info("[GDR] CheckpointTransfer initialized and cached for reuse")
224+
return ct_handle
225+
226+
def _destroy_gdr_handle(self):
227+
"""Destroy the cached GDR handle (e.g. on error)."""
228+
if self._gdr_ct_handle is not None:
229+
try:
230+
asyncio.run(self._gdr_ct_handle.cleanup())
231+
except Exception:
232+
pass
233+
self._gdr_ct_handle = None
234+
213235
def _build_ct_transfer_config(self, config: dict):
214236
from dataclasses import fields
215237

tests/rl/test_dynamic_weight_gdr.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def _make_manager(rsync_config=None, load_strategy="rsync"):
183183
manager.model_list = [_FakeModel()]
184184
manager.state_dict = {}
185185
manager.use_gdr_checkpoint_transfer = True
186+
manager._gdr_ct_handle = None
186187
return manager
187188

188189

@@ -256,25 +257,33 @@ def test_update_weights_by_gdr_gdr_mode(self):
256257
class FakeCheckpointTransfer:
257258
def __init__(self, config):
258259
self.config = config
260+
self.step_ids = []
259261
created.append(self)
260262

261263
def receive_weights_sync(self, step_id, output_framework="paddle"):
262-
self.step_id = step_id
264+
self.step_ids.append(step_id)
263265
self.output_framework = output_framework
264-
yield "model.layers.0.weight", object()
266+
yield f"model.layers.{len(self.step_ids)}.weight", object()
265267

266268
manager = _make_manager()
267269

268270
with _patch_gdr_checkpoint_transfer(FakeCheckpointTransfer):
269271
result = manager.update_weights_by_gdr(version="step-1")
272+
second_result = manager.update_weights_by_gdr(version="step-2")
270273

271274
self.assertEqual(result["version"], "step-1")
275+
self.assertEqual(second_result["version"], "step-2")
272276
self.assertEqual(result["update_count"], 1)
277+
self.assertEqual(second_result["update_count"], 1)
273278
self.assertIn("total_cost", result)
274-
self.assertEqual(manager.model_list[0].loaded[0][0], "model.layers.0.weight")
279+
self.assertEqual(
280+
[name for name, _ in manager.model_list[0].loaded], ["model.layers.1.weight", "model.layers.2.weight"]
281+
)
282+
self.assertEqual(len(created), 1)
283+
self.assertIs(manager._gdr_ct_handle, created[0])
275284
self.assertTrue(created[0].initialized)
276-
self.assertTrue(created[0].cleaned)
277-
self.assertEqual(created[0].step_id, "step-1")
285+
self.assertFalse(hasattr(created[0], "cleaned"))
286+
self.assertEqual(created[0].step_ids, ["step-1", "step-2"])
278287
self.assertEqual(created[0].output_framework, "paddle")
279288
self.assertEqual(created[0].config.kwargs["role"], _FakeRole.INFERENCE)
280289
self.assertEqual(created[0].config.kwargs["phase1_backend"], _FakePhase1Backend.GPU_DIRECT)
@@ -313,9 +322,11 @@ def receive_weights_sync(self, step_id, output_framework="paddle"):
313322
self.assertEqual(created[0].config.kwargs["qsize"], 2)
314323

315324
def test_gdr_checkpoint_transfer_receive_exception_propagates(self):
325+
created = []
326+
316327
class FakeCheckpointTransfer:
317328
def __init__(self, config):
318-
pass
329+
created.append(self)
319330

320331
def receive_weights_sync(self, step_id, output_framework="paddle"):
321332
yield "model.layers.0.weight", object()
@@ -333,6 +344,9 @@ def load_weights(self, weights_iterator):
333344
with self.assertRaisesRegex(RuntimeError, "receive failed"):
334345
manager.update_weights_by_gdr(version="step-error")
335346

347+
self.assertTrue(created[0].cleaned)
348+
self.assertIsNone(manager._gdr_ct_handle)
349+
336350
def test_gdr_checkpoint_transfer_refreshes_state_dict_after_model_loader(self):
337351
loaded_param = object()
338352

0 commit comments

Comments
 (0)