1414# limitations under the License.
1515"""
1616
17+ import asyncio
1718import gc
1819import glob
1920import os
2021import re
2122import time
2223from multiprocessing .shared_memory import SharedMemory
23- from typing import Any , Dict , List
24+ from typing import Any , Dict , Iterable , List , Optional , Tuple
2425
2526import numpy as np
2627import paddle
2728import yaml
2829from paddleformers .utils .log import logger
2930
31+ from fastdeploy import envs
3032from fastdeploy .config import FDConfig
3133from fastdeploy .inter_communicator import KVCacheStatus , ModelWeightsStatus
3234
@@ -52,10 +54,15 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
5254 self .model_list = models
5355 self ._capture_model_state ()
5456 self .rdma_handle = None
55- if self .load_config .load_strategy == "rsync" :
56- self .update_weights_by_rdma ()
57+ self .use_gdr_checkpoint_transfer = envs .FD_USE_GDR_CHECKPOINT_TRANSFER
58+
59+ if self .use_gdr_checkpoint_transfer :
60+ self .update_weights_by_gdr ()
5761 else :
58- self .update_parameters ()
62+ if self .load_config .load_strategy == "rsync" :
63+ self .update_weights_by_rdma ()
64+ else :
65+ self .update_parameters ()
5966 self .finalize_update ()
6067
6168 logger .info (
@@ -64,14 +71,20 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
6471 )
6572
@paddle.no_grad()
def _capture_model_state(self, log_params: bool = True):
    """Rebuild ``self.state_dict`` from the current parameters of every model.

    The mapping is reset to an empty dict first, then filled with each
    ``name -> param`` pair returned by ``model.state_dict()`` for the models
    in ``self.model_list``. When two models expose the same name, the later
    model's entry overwrites the earlier one.

    Args:
        log_params: When True, log name, shape, dtype and place of every
            parameter at info level while capturing.
    """
    snapshot = {}
    self.state_dict = snapshot
    for net in self.model_list:
        for name, param in net.state_dict().items():
            if log_params:
                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}")
            snapshot[name] = param
7382
74- def update_weights_by_rdma (self , version : str = None , verify_checksum : bool = False ):
83+ def update_weights_by_rdma (
84+ self ,
85+ version : str = None ,
86+ verify_checksum : bool = False ,
87+ ):
7588 def valid_parameters (old_state_dict , new_state_dict ):
7689 is_valid = True
7790 for key in new_state_dict :
@@ -92,14 +105,7 @@ def valid_parameters(old_state_dict, new_state_dict):
92105 )
93106 return is_valid
94107
95- bootstrap_load = version is None or version == ""
96- if bootstrap_load :
97- version = self .read_model_version_from_file ()
98- if version is None or version == "" :
99- raise Exception (
100- "rsync model version not set, please set it in 1) {model_version}/version.yaml "
101- "or 2) interface arguments 'version'"
102- )
108+ version , bootstrap_load = self ._resolve_weight_update_version (version )
103109
104110 logger .info (
105111 f"START rank:{ self .local_rank } /{ self .nranks } update_weights_by_rdma, "
@@ -151,6 +157,164 @@ def valid_parameters(old_state_dict, new_state_dict):
151157 "rank" : self .local_rank ,
152158 }
153159
def update_weights_by_gdr(
    self, version: str = None, verify_checksum: bool = False, restore_cleared_params: bool = False
):
    """Receive weights through CheckpointTransfer and load them into the models.

    The transfer backend is selected by ``_build_ct_transfer_config`` from the
    configured load strategy. For the "rsync" strategy the step id is the
    resolved weight version; for any other strategy it is ``version`` or
    ``"0"`` when no version is given.

    Args:
        version: Weight version / step id to receive. May be empty.
        verify_checksum: Accepted for signature compatibility; not read here.
        restore_cleared_params: When True, parameters in ``self.state_dict``
            that report themselves as uninitialized get fresh storage of the
            same shape and dtype before loading starts.

    Returns:
        A dict with the elapsed time (under both ``update_cost`` and
        ``total_cost``), the step id, the local rank, and the weight counts
        reported by ``_load_models_from_weight_iterator``.
    """
    rsync_options = dict(self.fd_config.load_config.rsync_config or {})
    strategy = self.load_config.load_strategy

    if strategy == "rsync":
        version, _ = self._resolve_weight_update_version(version)
        step_id = version
    else:
        step_id = version or "0"

    logger.info(
        f"START rank:{self.local_rank}/{self.nranks} update_weights_by_gdr, "
        f"load_strategy:{strategy}, step_id:{step_id}"
    )

    # Imported lazily so the dependency is only required when this path runs.
    from checkpoint_transfer.transfer import CheckpointTransfer

    transfer_config = self._build_ct_transfer_config(rsync_options)
    logger.info(f"CheckpointTransfer config:{transfer_config}")
    transfer = CheckpointTransfer(transfer_config)

    started_at = time.perf_counter()
    asyncio.run(transfer.initialize())
    try:
        incoming_weights = transfer.receive_weights_sync(step_id=step_id, output_framework="paddle")

        if restore_cleared_params:
            for name, param in self.state_dict.items():
                if param._is_initialized():
                    continue
                # Give the cleared parameter new storage so it can be written to.
                fresh_storage = paddle.empty(param.shape, dtype=param.dtype)
                fresh_storage._share_buffer_to(param)
                logger.debug(f"Restored cleared parameter storage before GDR checkpoint transfer load: {name}")

        update_count, mtp_cache_count = self._load_models_from_weight_iterator(incoming_weights)
    finally:
        # Always release the transfer handle, even when loading fails.
        asyncio.run(transfer.cleanup())

    # Refresh the cached name -> parameter mapping without re-logging every entry.
    self._capture_model_state(log_params=False)
    elapsed = time.perf_counter() - started_at
    logger.info(
        f"END update_weights_by_gdr, cost {elapsed:.2f} seconds, "
        f"weights:{update_count}, mtp_cached_weights:{mtp_cache_count}, "
        f"step_id:{step_id}, local_rank:{self.local_rank}"
    )
    summary = {
        "update_cost": elapsed,
        "total_cost": elapsed,
        "version": step_id,
        "rank": self.local_rank,
        "update_count": update_count,
        "mtp_cache_count": mtp_cache_count,
    }
    return summary
212+
def _build_ct_transfer_config(self, config: dict):
    """Translate the rsync-style option dict into a ``TransferConfig``.

    Works on a copy of ``config``; the caller's dict is not modified.

    * ``device_name`` is renamed to ``device`` unless ``device`` is already
      present, in which case ``device_name`` is discarded.
    * ``role`` is always set to ``Role.INFERENCE``.
    * With the "rsync" load strategy the global rank is
      ``index * nranks + local_rank`` and the GPU_DIRECT backend is used;
      otherwise the global rank comes from the ``FLAGS_selected_gpus``
      environment variable (default "0") and the IPC backend is used.
    * ``group_size`` defaults to ``self.nranks`` and ``qsize`` to 2; both are
      coerced to int.
    * Keys that are not fields of ``TransferConfig`` are dropped.

    Args:
        config: Raw option mapping.

    Returns:
        The constructed ``TransferConfig`` instance.
    """
    from dataclasses import fields

    from checkpoint_transfer.config import Phase1Backend, Role, TransferConfig

    options = dict(config)

    had_device_name = "device_name" in options
    legacy_device = options.pop("device_name", None)
    if had_device_name and "device" not in options:
        options["device"] = legacy_device

    options["role"] = Role.INFERENCE

    # "index" is consumed in both modes; only rsync mode interprets it.
    node_index = options.pop("index", 0)
    if self.load_config.load_strategy == "rsync":
        options["global_rank"] = int(node_index) * self.nranks + self.local_rank
        options["phase1_backend"] = Phase1Backend.GPU_DIRECT
    else:
        options["global_rank"] = int(os.getenv("FLAGS_selected_gpus", "0"))
        options["phase1_backend"] = Phase1Backend.IPC

    for key, fallback in (("group_size", self.nranks), ("qsize", 2)):
        options[key] = int(options.get(key, fallback))

    accepted = {spec.name for spec in fields(TransferConfig)}
    kwargs = {key: value for key, value in options.items() if key in accepted}
    return TransferConfig(**kwargs)
242+
243+ def _resolve_weight_update_version (self , version : Optional [str ]) -> Tuple [str , bool ]:
244+ bootstrap_load = version is None or version == ""
245+ if bootstrap_load :
246+ version = self .read_model_version_from_file ()
247+ if version is None or version == "" :
248+ raise Exception (
249+ "rsync model version not set, please set it in 1) {model_version}/version.yaml "
250+ "or 2) interface arguments 'version'"
251+ )
252+ return version , bootstrap_load
253+
254+ def _load_models_from_weight_iterator (
255+ self ,
256+ weights_iterator : Iterable [Tuple [str , Any ]],
257+ ) -> Tuple [int , int ]:
258+ update_count = 0
259+
260+ if len (self .model_list ) == 1 :
261+
262+ def count_weights ():
263+ nonlocal update_count
264+ for item in weights_iterator :
265+ update_count += 1
266+ yield item
267+
268+ self .model_list [0 ].load_weights (count_weights ())
269+ return update_count , 0
270+
271+ mtp_models = self .model_list [1 :]
272+ config = self .fd_config .load_config .rsync_config or {}
273+ mtp_chunk_size = max (1 , int (config .get ("gdr_mtp_chunk_size" , 16 )))
274+ mtp_chunk : List [Tuple [str , Any ]] = []
275+ mtp_cache_count = 0
276+ mtp_weight_tokens = ["mtp_" , "mtp_block" ]
277+ for model in mtp_models :
278+ model_config = getattr (getattr (model , "fd_config" , None ), "model_config" , None )
279+ start_layer = getattr (model , "mtp_start_layer_idx" , None )
280+ num_layers = getattr (model , "num_mtp_layers" , None )
281+ start_layer = start_layer if start_layer is not None else getattr (model_config , "start_layer_index" , None )
282+ num_layers = (
283+ num_layers if num_layers is not None else getattr (model_config , "num_nextn_predict_layers" , None )
284+ )
285+ if start_layer is None or num_layers is None :
286+ continue
287+ for layer_id in range (int (start_layer ), int (start_layer ) + int (num_layers )):
288+ mtp_weight_tokens .append (f"layers.{ layer_id } ." )
289+ mtp_weight_tokens .append (f".layers.{ layer_id } ." )
290+
291+ def flush_mtp_chunk ():
292+ nonlocal mtp_chunk
293+ if not mtp_chunk :
294+ return
295+ for model in mtp_models :
296+ model .load_weights (iter (mtp_chunk ))
297+ mtp_chunk = []
298+
299+ def cache_mtp_weights ():
300+ nonlocal update_count , mtp_cache_count
301+ for item in weights_iterator :
302+ name , _ = item
303+ update_count += 1
304+ if any (token in name for token in mtp_weight_tokens ):
305+ mtp_chunk .append (item )
306+ mtp_cache_count += 1
307+ yield item
308+ if len (mtp_chunk ) >= mtp_chunk_size :
309+ flush_mtp_chunk ()
310+
311+ self .model_list [0 ].load_weights (cache_mtp_weights ())
312+ flush_mtp_chunk ()
313+ if mtp_cache_count == 0 :
314+ raise ValueError ("No MTP weights were cached from the GDR stream for auxiliary model loading." )
315+
316+ return update_count , mtp_cache_count
317+
154318 def update_parameters (self , pid : int = 0 , restart_process_group = False ) -> None :
155319 """Core method to update model parameters based on strategy."""
156320 start_time = time .perf_counter ()
@@ -414,7 +578,7 @@ def _validate_parameter_match(self, name: str, src: paddle.Tensor, dst: paddle.T
414578 if src .shape != dst .shape :
415579 raise ValueError (f"Shape mismatch for { name } : { src .shape } vs { dst .shape } " )
416580
417- def finalize_update (self , pid : int = 0 ):
581+ def finalize_update (self , pid : Optional [ int ] = None ):
418582 """Finalize update process with verification."""
419583 self ._verify_parameters ("update" )
420584
@@ -479,8 +643,10 @@ def _log_memory(self, context: str):
479643 f"current_reserved: { curr_reserved :.2f} GB"
480644 )
481645
482- def _update_shared_status (self , pid : int , status : int ) -> None :
646+ def _update_shared_status (self , pid : Optional [ int ] , status : int ) -> None :
483647 """Update shared memory status flag for inter-process communication."""
648+ if pid is None :
649+ pid = self .parallel_config .local_engine_worker_queue_port
484650 array = np .zeros ([1 ], dtype = np .int32 )
485651 shm = SharedMemory (create = False , size = array .nbytes , name = f"model_weights_status.{ pid } " )
486652 value = np .ndarray (array .shape , dtype = array .dtype , buffer = shm .buf )
0 commit comments