@@ -141,15 +141,17 @@ def get_orthogonal_matrix(size, mode="hadamard", device="cuda"):
141141 raise ValueError (f"Unknown mode { mode } " )
142142
143143
144- def rotate_model (state_dict , layer_idx , moe_num_experts = 48 , hidden_size = 7168 , moe_intermediate_size = 3584 , ep_rank = 0 ):
144+ def rotate_model (
145+ state_dict , prefix_layer_name , layer_idx , moe_num_experts , hidden_size , moe_intermediate_size , ep_rank = 0
146+ ):
145147 with paddle .no_grad ():
146148 # collect hadamard rotation matrix [moe_intermediate_size, moe_intermediate_size]
147149 Q_ffn2 , moe_block_size = get_orthogonal_matrix (size = moe_intermediate_size , mode = "hadamard_ffn2" )
148150 # down_proj.weight: [moe_intermediate_size, hidden_size]
149151 expert_list = [
150152 get_tensor (
151153 state_dict [
152- f"ernie.layers .{ layer_idx } .mlp.experts.{ ep_rank * moe_num_experts + expert_idx } .down_proj.weight"
154+ f"ernie.{ prefix_layer_name } .{ layer_idx } .mlp.experts.{ ep_rank * moe_num_experts + expert_idx } .down_proj.weight"
153155 ]
154156 )
155157 for expert_idx in range (moe_num_experts )
@@ -159,7 +161,7 @@ def rotate_model(state_dict, layer_idx, moe_num_experts=48, hidden_size=7168, mo
159161 for expert_idx in range (moe_num_experts ):
160162 rotated_weight = new_moe_weight [:, expert_idx * hidden_size : (expert_idx + 1 ) * hidden_size ]
161163 expert_idx_local = ep_rank * moe_num_experts + expert_idx
162- state_dict [f"ernie.layers .{ layer_idx } .mlp.experts.{ expert_idx_local } .down_proj.weight" ] = (
164+ state_dict [f"ernie.{ prefix_layer_name } .{ layer_idx } .mlp.experts.{ expert_idx_local } .down_proj.weight" ] = (
163165 rotated_weight .cpu ()
164166 )
165167 del moe_weight , new_moe_weight , rotated_weight
0 commit comments