@@ -176,17 +176,24 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
176176
177177 if shard_id is None :
178178 # 1.gate up fused in disk
179+ model_format = getattr (param , "model_format" , "" )
180+ is_torch_model = model_format == "torch"
179181 output_size = param [expert_id - self .expert_id_offset ].shape [SHARD_ID_TO_SHARDED_DIM ["gate" ]]
180- shard_offsets = [
181- # (shard_id, shard_offset, shard_size)
182- ("gate" , 0 , output_size // 2 * self .tp_size ),
183- ("up" , output_size // 2 * self .tp_size , output_size // 2 * self .tp_size ),
184- ]
185- for shard_id , shard_offset , shard_size in shard_offsets :
186- loaded_weight_shard = slice_fn (
187- loaded_weight , SHARD_ID_TO_SHARDED_DIM [shard_id ], shard_offset , shard_offset + shard_size
188- )
189- self .weight_loader (param , loaded_weight_shard , expert_id , shard_id )
182+ per_rank = output_size // 2
183+ start = self .tp_rank * per_rank
184+ loaded_weight_shard_gate = slice_fn (
185+ loaded_weight , is_torch_model ^ SHARD_ID_TO_SHARDED_DIM ["gate" ], start , start + per_rank
186+ )
187+ self ._load_gate_up_weight (
188+ param , expert_id , loaded_weight_shard_gate , "gate" , SHARD_ID_TO_SHARDED_DIM ["gate" ], is_sharded = True
189+ )
190+ start_up = output_size // 2 * self .tp_size + self .tp_rank * per_rank
191+ loaded_weight_shard_up = slice_fn (
192+ loaded_weight , is_torch_model ^ SHARD_ID_TO_SHARDED_DIM ["up" ], start_up , start_up + per_rank
193+ )
194+ self ._load_gate_up_weight (
195+ param , expert_id , loaded_weight_shard_up , "up" , SHARD_ID_TO_SHARDED_DIM ["up" ], is_sharded = True
196+ )
190197 else :
191198 # 2. gate and up are stored as separate tensors on disk
192199 assert shard_id in ["gate" , "down" , "up" ]
@@ -198,22 +205,23 @@ def weight_loader(self, param, loaded_weight, expert_id, shard_id: Optional[str]
198205 shard_dim = SHARD_ID_TO_SHARDED_DIM [shard_id ],
199206 )
200207
201- def _load_gate_up_weight (self , param , expert_id , loaded_weight , shard_id , shard_dim = None ):
208+ def _load_gate_up_weight (self , param , expert_id , loaded_weight , shard_id , shard_dim = None , is_sharded = False ):
202209 model_format = getattr (param , "model_format" , "" )
203- if model_format == "torch" :
204- loaded_weight = loaded_weight . transpose ([ 1 , 0 ])
205- dim = - 1 if shard_dim else 0
206- if self . tp_size > 1 :
210+ is_torch_model = model_format == "torch"
211+ if self . tp_size > 1 and not is_sharded :
212+ tp_shard_dim = is_torch_model ^ shard_dim
213+ weight_dim = - 1 if tp_shard_dim else 0
207214 if isinstance (loaded_weight , (np .ndarray , paddle .Tensor )):
208- size = loaded_weight .shape [dim ]
215+ size = loaded_weight .shape [weight_dim ]
209216 else :
210- size = loaded_weight .get_shape ()[dim ]
217+ size = loaded_weight .get_shape ()[weight_dim ]
211218 block_size = size // self .tp_size
212219 shard_offset = self .tp_rank * block_size
213220 shard_size = (self .tp_rank + 1 ) * block_size
214- loaded_weight = slice_fn (loaded_weight , shard_dim , shard_offset , shard_size )
215-
221+ loaded_weight = slice_fn (loaded_weight , tp_shard_dim , shard_offset , shard_size )
222+ loaded_weight = get_tensor ( loaded_weight )
216223 expert_param = param [expert_id - self .expert_id_offset ]
224+ dim = - 1 if shard_dim else 0
217225 param_shard_size = expert_param .shape [dim ] // 2
218226 if shard_id == "gate" :
219227 param_shard_offset = 0
@@ -232,36 +240,35 @@ def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_
232240 )
233241
234242 # To ensure compatibility across backends (e.g. GCU, XPU) and weight layouts, transpose whenever the loaded shape differs from the parameter shape
235- if current_platform .is_xpu () or current_platform .is_gcu ():
236- if expert_param .shape != loaded_weight .shape :
237- loaded_weight = loaded_weight .transpose ([1 , 0 ])
243+ if expert_param .shape != loaded_weight .shape :
244+ loaded_weight = loaded_weight .transpose ([1 , 0 ])
238245 assert expert_param .shape == loaded_weight .shape , (
239246 f"Attempted to load weight ({ loaded_weight .shape } ) " f"into parameter ({ expert_param .shape } )"
240247 )
241248 expert_param .copy_ (loaded_weight , False )
242249
243250 def _load_down_weight (self , param , expert_id , loaded_weight , shard_id , shard_dim = None ):
244251 model_format = getattr (param , "model_format" , "" )
245- if model_format == "torch" :
246- loaded_weight = loaded_weight .transpose ([1 , 0 ])
252+ is_torch_model = model_format == "torch"
247253 if self .tp_size > 1 and shard_dim is not None :
248- dim = - 1 if shard_dim else 0
249- if isinstance (loaded_weight , (np .ndarray , paddle .Tensor )):
254+ tp_shard_dim = is_torch_model ^ shard_dim
255+ dim = - 1 if tp_shard_dim else 0
256+ if isinstance (loaded_weight , paddle .Tensor ):
250257 size = loaded_weight .shape [dim ]
251258 else :
252259 size = loaded_weight .get_shape ()[dim ]
253260 block_size = size // self .tp_size
254261 shard_offset = self .tp_rank * block_size
255262 shard_size = (self .tp_rank + 1 ) * block_size
256- loaded_weight = slice_fn (loaded_weight , shard_dim , shard_offset , shard_size )
263+ loaded_weight = slice_fn (loaded_weight , tp_shard_dim , shard_offset , shard_size )
264+ loaded_weight = get_tensor (loaded_weight )
257265 expert_param = param [expert_id - self .expert_id_offset ]
258266 if hasattr (param , "tensor_track" ):
259267 # for dyn quant
260268 param .tensor_track .mark (start = 0 , batch_id = expert_id - self .expert_id_offset )
261- # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
262- if current_platform .is_xpu or current_platform .is_gcu ():
263- if expert_param .shape != loaded_weight .shape :
264- loaded_weight = loaded_weight .transpose ([1 , 0 ])
269+ # To ensure compatibility across backends (e.g. GCU, XPU) and open-source weights, transpose whenever the loaded shape differs from the parameter shape
270+ if expert_param .shape != loaded_weight .shape :
271+ loaded_weight = loaded_weight .transpose ([1 , 0 ])
265272 assert expert_param .shape == loaded_weight .shape , (
266273 f"Attempted to load weight ({ loaded_weight .shape } ) " f"into parameter ({ expert_param .shape } )"
267274 )
0 commit comments