@@ -279,9 +279,22 @@ def _restack_from_views(self) -> None:
279279 device = torch .cuda .current_device ()
280280 for attr_name in ("experts_gate_up_weight" , "experts_down_weight" ):
281281 old_param = getattr (self , attr_name )
282- new_data = torch .empty_like (old_param , device = device )
283- torch .nn .init .normal_ (new_data , mean = 0.0 , std = self .initializer_range )
284- setattr (self , attr_name , nn .Parameter (new_data ))
282+ if isinstance (old_param .data , DTensor ):
283+ # FSDP2 has sharded this param; materialize the local shard on CUDA
284+ # and reconstruct the DTensor wrapper so FSDP2 can manage it.
285+ local_data = old_param .data .to_local ()
286+ new_local = torch .empty (local_data .shape , dtype = local_data .dtype , device = device )
287+ torch .nn .init .normal_ (new_local , mean = 0.0 , std = self .initializer_range )
288+ new_dtensor = DTensor .from_local (
289+ new_local ,
290+ device_mesh = old_param .data .device_mesh ,
291+ placements = old_param .data .placements ,
292+ )
293+ setattr (self , attr_name , nn .Parameter (new_dtensor ))
294+ else :
295+ new_data = torch .empty_like (old_param , device = device )
296+ torch .nn .init .normal_ (new_data , mean = 0.0 , std = self .initializer_range )
297+ setattr (self , attr_name , nn .Parameter (new_data ))
285298
286299 # Re-sync views to point to the new stacked parameter
287300 self ._sync_expert_views ()
@@ -298,13 +311,15 @@ def _sync_expert_views(self) -> None:
298311 gate_up_w = self .experts_gate_up_weight
299312 if isinstance (gate_up_w , DTensor ):
300313 gate_up_w = gate_up_w .to_local ()
301- for i in range (self .num_local_experts ):
314+ num_local = gate_up_w .shape [0 ]
315+ for i in range (num_local ):
302316 object .__setattr__ (self .experts_gate_up , f"weight{ i } " , gate_up_w [i ])
303317
304318 down_w = self .experts_down_weight
305319 if isinstance (down_w , DTensor ):
306320 down_w = down_w .to_local ()
307- for i in range (self .num_local_experts ):
321+ num_local_down = down_w .shape [0 ]
322+ for i in range (num_local_down ):
308323 object .__setattr__ (self .experts_down , f"weight{ i } " , down_w [i ])
309324
310325 def set_ep_group (self , ep_group : dist .ProcessGroup , ep_mesh : DeviceMesh ) -> None :
0 commit comments