Skip to content

Commit 245a5e0

Browse files
authored
[RL][Feature] Add GDR streaming weight update path (#7951)
* Add GDR streaming weight update path * [RL] Unify GDR and IPC weight update
1 parent cc9af4d commit 245a5e0

8 files changed

Lines changed: 1020 additions & 22 deletions

File tree

fastdeploy/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,6 +1456,8 @@ def __init__(
14561456
self.model_loader_extra_config: Optional[Dict[str, Any]] = None
14571457
for key, value in args.items():
14581458
if hasattr(self, key):
1459+
if key == "rsync_config" and isinstance(value, str):
1460+
value = json.loads(value)
14591461
setattr(self, key, value)
14601462

14611463
def __str__(self) -> str:

fastdeploy/envs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,8 @@ def _validate_split_kv_size(value: int) -> int:
263263
"FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST": lambda: bool(
264264
int(os.getenv("FD_SAVE_OUTPUT_CACHE_FOR_PREEMPTED_REQUEST", "1"))
265265
),
266+
# Whether to use GDR CheckpointTransfer for dynamic weight updates.
267+
"FD_USE_GDR_CHECKPOINT_TRANSFER": lambda: bool(int(os.getenv("FD_USE_GDR_CHECKPOINT_TRANSFER", "0"))),
266268
# Whether to enable block-wise CUDA Graph capture/replay.
267269
# When enabled, individual layer forward methods decorated with @block_wise_cuda_graph_wrap
268270
# will be captured and replayed as CUDA Graphs for improved performance.

fastdeploy/model_executor/utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,35 @@ def slice_fn(weight_or_parameter, output_dim, start, end, step=1):
131131
return weight_or_parameter
132132

133133

134+
def _is_gdr_checkpoint_transfer_dynamic_load_config(fd_config: FDConfig) -> bool:
135+
load_config = fd_config.load_config
136+
if not load_config.dynamic_load_weight:
137+
return False
138+
return envs.FD_USE_GDR_CHECKPOINT_TRANSFER
139+
140+
141+
def _copy_gdr_checkpoint_transfer_transposed_weight_attrs(src, dst):
142+
attr_names = (
143+
"weight_loader",
144+
"output_dim",
145+
"weight_need_transpose",
146+
"is_distributed",
147+
"split_axis",
148+
"tp_row_bias",
149+
)
150+
for name in attr_names:
151+
if hasattr(src, name):
152+
setattr(dst, name, getattr(src, name))
153+
if hasattr(src, "output_dim") and src.output_dim is not None:
154+
dst.output_dim = not src.output_dim
155+
dst.weight_need_transpose = not getattr(src, "weight_need_transpose", False)
156+
if hasattr(src, "split_axis"):
157+
if len(src.shape) == 2 and src.split_axis in (0, 1):
158+
dst.split_axis = 1 - src.split_axis
159+
elif len(src.shape) == 3 and src.split_axis in (1, 2):
160+
dst.split_axis = 3 - src.split_axis
161+
162+
134163
def process_weight_transpose(layer, weight_name):
135164
weight = getattr(layer, weight_name)
136165
if len(weight.shape) == 2:
@@ -143,6 +172,8 @@ def process_weight_transpose(layer, weight_name):
143172
default_initializer=paddle.nn.initializer.Constant(0),
144173
is_bias=False,
145174
)
175+
if _is_gdr_checkpoint_transfer_dynamic_load_config(layer.fd_config):
176+
_copy_gdr_checkpoint_transfer_transposed_weight_attrs(weight, weight_tmp)
146177
if layer.fd_config.load_config.dynamic_load_weight or getattr(layer.fd_config.model_config, "enable_cache", False):
147178
free_tensor(weight)
148179
setattr(layer, weight_name, weight_tmp)
@@ -361,6 +392,8 @@ def fn(param, loaded_weight, shard_id: Optional[Union[int, str]] = None):
361392
f" Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({param.shape})"
362393
)
363394
loaded_weight = get_tensor(loaded_weight)
395+
if not param._is_initialized():
396+
param.initialize()
364397
param.copy_(loaded_weight, False)
365398

366399
return fn

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 183 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,21 @@
1414
# limitations under the License.
1515
"""
1616

17+
import asyncio
1718
import gc
1819
import glob
1920
import os
2021
import re
2122
import time
2223
from multiprocessing.shared_memory import SharedMemory
23-
from typing import Any, Dict, List
24+
from typing import Any, Dict, Iterable, List, Optional, Tuple
2425

2526
import numpy as np
2627
import paddle
2728
import yaml
2829
from paddleformers.utils.log import logger
2930

31+
from fastdeploy import envs
3032
from fastdeploy.config import FDConfig
3133
from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus
3234

@@ -52,10 +54,15 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
5254
self.model_list = models
5355
self._capture_model_state()
5456
self.rdma_handle = None
55-
if self.load_config.load_strategy == "rsync":
56-
self.update_weights_by_rdma()
57+
self.use_gdr_checkpoint_transfer = envs.FD_USE_GDR_CHECKPOINT_TRANSFER
58+
59+
if self.use_gdr_checkpoint_transfer:
60+
self.update_weights_by_gdr()
5761
else:
58-
self.update_parameters()
62+
if self.load_config.load_strategy == "rsync":
63+
self.update_weights_by_rdma()
64+
else:
65+
self.update_parameters()
5966
self.finalize_update()
6067

6168
logger.info(
@@ -64,14 +71,20 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
6471
)
6572

6673
@paddle.no_grad()
67-
def _capture_model_state(self):
74+
def _capture_model_state(self, log_params: bool = True):
6875
"""Capture and store initial model parameters state."""
76+
self.state_dict = {}
6977
for model in self.model_list:
7078
for name, param in model.state_dict().items():
71-
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}")
79+
if log_params:
80+
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}")
7281
self.state_dict[name] = param
7382

74-
def update_weights_by_rdma(self, version: str = None, verify_checksum: bool = False):
83+
def update_weights_by_rdma(
84+
self,
85+
version: str = None,
86+
verify_checksum: bool = False,
87+
):
7588
def valid_parameters(old_state_dict, new_state_dict):
7689
is_valid = True
7790
for key in new_state_dict:
@@ -92,14 +105,7 @@ def valid_parameters(old_state_dict, new_state_dict):
92105
)
93106
return is_valid
94107

95-
bootstrap_load = version is None or version == ""
96-
if bootstrap_load:
97-
version = self.read_model_version_from_file()
98-
if version is None or version == "":
99-
raise Exception(
100-
"rsync model version not set, please set it in 1) {model_version}/version.yaml "
101-
"or 2) interface arguments 'version'"
102-
)
108+
version, bootstrap_load = self._resolve_weight_update_version(version)
103109

104110
logger.info(
105111
f"START rank:{self.local_rank}/{self.nranks} update_weights_by_rdma, "
@@ -151,6 +157,164 @@ def valid_parameters(old_state_dict, new_state_dict):
151157
"rank": self.local_rank,
152158
}
153159

160+
def update_weights_by_gdr(
    self, version: str = None, verify_checksum: bool = False, restore_cleared_params: bool = False
):
    """Unified weight update via CheckpointTransfer (supports GDR and IPC backends).

    Args:
        version: Requested weight version. When ``load_strategy`` is ``"rsync"``
            it is resolved through ``_resolve_weight_update_version``; for any
            other strategy it is used as the step id directly, falling back to
            ``"0"`` when empty or ``None``.
        verify_checksum: Accepted for signature parity; not referenced in this
            method body.
        restore_cleared_params: When true, every entry of ``self.state_dict``
            whose ``_is_initialized()`` is false gets a freshly allocated buffer
            before the weights are loaded.

    Returns:
        A dict with ``update_cost``, ``total_cost`` (both hold the same elapsed
        time), ``version`` (the step id), ``rank``, ``update_count`` and
        ``mtp_cache_count``.
    """
    # Work on a copy so the shared rsync_config mapping is never mutated here.
    config = dict(self.fd_config.load_config.rsync_config or {})
    # Every load strategy other than "rsync" is treated as the IPC variant.
    is_ipc = self.load_config.load_strategy != "rsync"

    if is_ipc:
        step_id = version or "0"
    else:
        # The bootstrap flag returned by the resolver is not needed on this path.
        version, _ = self._resolve_weight_update_version(version)
        step_id = version

    logger.info(
        f"START rank:{self.local_rank}/{self.nranks} update_weights_by_gdr, "
        f"load_strategy:{self.load_config.load_strategy}, step_id:{step_id}"
    )

    # Function-scope import: the package is only required when this path runs.
    from checkpoint_transfer.transfer import CheckpointTransfer

    transfer_config = self._build_ct_transfer_config(config)
    logger.info(f"CheckpointTransfer config:{transfer_config}")
    ct_handle = CheckpointTransfer(transfer_config)

    # The timer spans initialize, receive/load, cleanup and the state re-capture.
    total_start = time.perf_counter()
    # NOTE(review): initialize() runs before the try block, so cleanup() is not
    # invoked if initialization itself raises — confirm this is intended.
    asyncio.run(ct_handle.initialize())
    try:
        weights_iterator = ct_handle.receive_weights_sync(step_id=step_id, output_framework="paddle")

        if restore_cleared_params:
            for name, target_param in self.state_dict.items():
                if not target_param._is_initialized():
                    # Allocate a buffer of matching shape/dtype and share it into the parameter.
                    paddle.empty(target_param.shape, dtype=target_param.dtype)._share_buffer_to(target_param)
                    logger.debug(f"Restored cleared parameter storage before GDR checkpoint transfer load: {name}")
        update_count, mtp_cache_count = self._load_models_from_weight_iterator(weights_iterator)
    finally:
        # Always release the transfer handle, even when loading failed.
        asyncio.run(ct_handle.cleanup())
    # NOTE(review): assumed to run only after a successful load (i.e. outside the
    # finally block) — confirm against the original indentation.
    # Rebuilds self.state_dict from the models, without per-parameter logging.
    self._capture_model_state(log_params=False)
    total_cost = time.perf_counter() - total_start
    logger.info(
        f"END update_weights_by_gdr, cost {total_cost:.2f} seconds, "
        f"weights:{update_count}, mtp_cached_weights:{mtp_cache_count}, "
        f"step_id:{step_id}, local_rank:{self.local_rank}"
    )
    return {
        "update_cost": total_cost,
        "total_cost": total_cost,
        "version": step_id,
        "rank": self.local_rank,
        "update_count": update_count,
        "mtp_cache_count": mtp_cache_count,
    }
212+
213+
def _build_ct_transfer_config(self, config: dict):
    """Build a CheckpointTransfer ``TransferConfig`` from an rsync_config-style mapping.

    The input mapping is copied, never mutated. The copy is adjusted as follows:

    * ``device_name`` is renamed to ``device`` when ``device`` is absent,
      otherwise ``device_name`` is dropped;
    * ``role`` is forced to ``Role.INFERENCE``;
    * for ``load_strategy == "rsync"``: ``global_rank`` is
      ``index * nranks + local_rank`` (``index`` defaults to 0) and the phase-1
      backend is ``GPU_DIRECT``;
    * for any other strategy: ``index`` is discarded, ``global_rank`` is the
      first device id listed in ``FLAGS_selected_gpus`` (default ``"0"``), the
      phase-1 backend is ``IPC`` and ``qsize`` defaults to 2;
    * ``group_size`` defaults to ``self.nranks`` in both cases;
    * keys that are not fields of ``TransferConfig`` are dropped.

    Args:
        config: Mapping of transfer options (typically ``load_config.rsync_config``).

    Returns:
        The constructed ``TransferConfig`` instance.

    Raises:
        ValueError: If ``index``, ``group_size``, ``qsize`` or the selected GPU
            id cannot be converted to ``int``.
    """
    from dataclasses import fields

    from checkpoint_transfer.config import Phase1Backend, Role, TransferConfig

    transfer_config = dict(config)
    if "device_name" in transfer_config and "device" not in transfer_config:
        transfer_config["device"] = transfer_config.pop("device_name")
    else:
        transfer_config.pop("device_name", None)

    transfer_config["role"] = Role.INFERENCE

    if self.load_config.load_strategy == "rsync":
        node_index = int(transfer_config.pop("index", 0))
        transfer_config["global_rank"] = node_index * self.nranks + self.local_rank
        transfer_config["phase1_backend"] = Phase1Backend.GPU_DIRECT
    else:
        transfer_config.pop("index", None)
        # FLAGS_selected_gpus may hold a comma-separated list of device ids;
        # a bare int() on the whole string would fail for e.g. "0,1". Use the
        # first entry, which is also what Paddle's ParallelEnv does.
        selected_gpus = os.getenv("FLAGS_selected_gpus", "0")
        gpu_id = int(selected_gpus.split(",")[0])
        transfer_config["global_rank"] = gpu_id
        transfer_config["phase1_backend"] = Phase1Backend.IPC
        transfer_config["qsize"] = int(transfer_config.get("qsize", 2))

    # Identical for both strategies, so it is set once here.
    transfer_config["group_size"] = int(transfer_config.get("group_size", self.nranks))

    # Only forward options that TransferConfig actually declares.
    transfer_config_keys = {field.name for field in fields(TransferConfig)}
    transfer_config = {key: value for key, value in transfer_config.items() if key in transfer_config_keys}
    return TransferConfig(**transfer_config)
242+
243+
def _resolve_weight_update_version(self, version: Optional[str]) -> Tuple[str, bool]:
244+
bootstrap_load = version is None or version == ""
245+
if bootstrap_load:
246+
version = self.read_model_version_from_file()
247+
if version is None or version == "":
248+
raise Exception(
249+
"rsync model version not set, please set it in 1) {model_version}/version.yaml "
250+
"or 2) interface arguments 'version'"
251+
)
252+
return version, bootstrap_load
253+
254+
def _load_models_from_weight_iterator(
255+
self,
256+
weights_iterator: Iterable[Tuple[str, Any]],
257+
) -> Tuple[int, int]:
258+
update_count = 0
259+
260+
if len(self.model_list) == 1:
261+
262+
def count_weights():
263+
nonlocal update_count
264+
for item in weights_iterator:
265+
update_count += 1
266+
yield item
267+
268+
self.model_list[0].load_weights(count_weights())
269+
return update_count, 0
270+
271+
mtp_models = self.model_list[1:]
272+
config = self.fd_config.load_config.rsync_config or {}
273+
mtp_chunk_size = max(1, int(config.get("gdr_mtp_chunk_size", 16)))
274+
mtp_chunk: List[Tuple[str, Any]] = []
275+
mtp_cache_count = 0
276+
mtp_weight_tokens = ["mtp_", "mtp_block"]
277+
for model in mtp_models:
278+
model_config = getattr(getattr(model, "fd_config", None), "model_config", None)
279+
start_layer = getattr(model, "mtp_start_layer_idx", None)
280+
num_layers = getattr(model, "num_mtp_layers", None)
281+
start_layer = start_layer if start_layer is not None else getattr(model_config, "start_layer_index", None)
282+
num_layers = (
283+
num_layers if num_layers is not None else getattr(model_config, "num_nextn_predict_layers", None)
284+
)
285+
if start_layer is None or num_layers is None:
286+
continue
287+
for layer_id in range(int(start_layer), int(start_layer) + int(num_layers)):
288+
mtp_weight_tokens.append(f"layers.{layer_id}.")
289+
mtp_weight_tokens.append(f".layers.{layer_id}.")
290+
291+
def flush_mtp_chunk():
292+
nonlocal mtp_chunk
293+
if not mtp_chunk:
294+
return
295+
for model in mtp_models:
296+
model.load_weights(iter(mtp_chunk))
297+
mtp_chunk = []
298+
299+
def cache_mtp_weights():
300+
nonlocal update_count, mtp_cache_count
301+
for item in weights_iterator:
302+
name, _ = item
303+
update_count += 1
304+
if any(token in name for token in mtp_weight_tokens):
305+
mtp_chunk.append(item)
306+
mtp_cache_count += 1
307+
yield item
308+
if len(mtp_chunk) >= mtp_chunk_size:
309+
flush_mtp_chunk()
310+
311+
self.model_list[0].load_weights(cache_mtp_weights())
312+
flush_mtp_chunk()
313+
if mtp_cache_count == 0:
314+
raise ValueError("No MTP weights were cached from the GDR stream for auxiliary model loading.")
315+
316+
return update_count, mtp_cache_count
317+
154318
def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
155319
"""Core method to update model parameters based on strategy."""
156320
start_time = time.perf_counter()
@@ -414,7 +578,7 @@ def _validate_parameter_match(self, name: str, src: paddle.Tensor, dst: paddle.T
414578
if src.shape != dst.shape:
415579
raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")
416580

417-
def finalize_update(self, pid: int = 0):
581+
def finalize_update(self, pid: Optional[int] = None):
418582
"""Finalize update process with verification."""
419583
self._verify_parameters("update")
420584

@@ -479,8 +643,10 @@ def _log_memory(self, context: str):
479643
f"current_reserved: {curr_reserved:.2f}GB"
480644
)
481645

482-
def _update_shared_status(self, pid: int, status: int) -> None:
646+
def _update_shared_status(self, pid: Optional[int], status: int) -> None:
483647
"""Update shared memory status flag for inter-process communication."""
648+
if pid is None:
649+
pid = self.parallel_config.local_engine_worker_queue_port
484650
array = np.zeros([1], dtype=np.int32)
485651
shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
486652
value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)

0 commit comments

Comments
 (0)