@@ -505,10 +505,9 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
505505
def reshape_kernel(input_tensor, target_shape):
  """Reshapes and transposes kernel weights between the two checkpoint layouts.

  The direction is chosen by `saving_to_hf`, which is read from the enclosing
  scope. The last two axes are treated as the matrix axes; any leading axes
  are carried through unchanged.

  Args:
    input_tensor: NumPy array holding the kernel in the source layout.
    target_shape: Sequence of ints giving the shape of the returned array.

  Returns:
    An array of shape `target_shape` with the matrix axes transposed.
  """
  if saving_to_hf:
    # Exact inverse of the load path below: that path transposes and then
    # reshapes, so here we must undo the reshape first and transpose last.
    # Swapping first would move the wrong axes whenever input_tensor has more
    # axes than target_shape (i.e. the reshape merges axes).
    pre_swap_shape = list(target_shape)
    pre_swap_shape[-1], pre_swap_shape[-2] = pre_swap_shape[-2], pre_swap_shape[-1]
    return np.swapaxes(input_tensor.reshape(pre_swap_shape), -1, -2)
  return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
512511
513512 def scale_rmsnorm_layer (input_tensor , target_shape ):
514513 if saving_to_hf :
@@ -773,10 +772,9 @@ def pad_embedding_layer(input_tensor, target_shape):
def reshape_kernel(input_tensor, target_shape):
  """Reshapes and transposes kernel weights between MaxText and HF.

  The direction is chosen by `saving_to_hf`, which is read from the enclosing
  scope. The last two axes are treated as the matrix axes; any leading axes
  are carried through unchanged.

  Args:
    input_tensor: NumPy array holding the kernel in the source layout.
    target_shape: Sequence of ints giving the shape of the returned array.

  Returns:
    An array of shape `target_shape` with the matrix axes transposed.
  """
  if saving_to_hf:
    # Exact inverse of the load path below: that path transposes and then
    # reshapes, so here we must undo the reshape first and transpose last.
    # Swapping first would move the wrong axes whenever input_tensor has more
    # axes than target_shape (i.e. the reshape merges axes).
    pre_swap_shape = list(target_shape)
    pre_swap_shape[-1], pre_swap_shape[-2] = pre_swap_shape[-2], pre_swap_shape[-1]
    return np.swapaxes(input_tensor.reshape(pre_swap_shape), -1, -2)
  return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
780778
781779 def reshape_bias (input_tensor , target_shape = None ):
782780 """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
@@ -1019,10 +1017,9 @@ def transpose(input_tensor, target_shape=None):
10191017
def reshape_kernel(input_tensor, target_shape):
  """Reshapes and transposes kernel weights between the two checkpoint layouts.

  The direction is chosen by `saving_to_hf`, which is read from the enclosing
  scope. The last two axes are treated as the matrix axes; any leading axes
  are carried through unchanged.

  Args:
    input_tensor: NumPy array holding the kernel in the source layout.
    target_shape: Sequence of ints giving the shape of the returned array.

  Returns:
    An array of shape `target_shape` with the matrix axes transposed.
  """
  if saving_to_hf:
    # Exact inverse of the load path below: that path transposes and then
    # reshapes, so here we must undo the reshape first and transpose last.
    # Swapping first would move the wrong axes whenever input_tensor has more
    # axes than target_shape (i.e. the reshape merges axes).
    pre_swap_shape = list(target_shape)
    pre_swap_shape[-1], pre_swap_shape[-2] = pre_swap_shape[-2], pre_swap_shape[-1]
    return np.swapaxes(input_tensor.reshape(pre_swap_shape), -1, -2)
  return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
10261023
10271024 def permute_conv (input_tensor , target_shape = None ):
10281025 # MT: [K, 1, C] <-> HF: [C, 1, K]
@@ -1174,10 +1171,9 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=Fal
def reshape_kernel(input_tensor, target_shape):
  """Reshapes and transposes kernel weights between MaxText and HF.

  The direction is chosen by `saving_to_hf`, which is read from the enclosing
  scope. The last two axes are treated as the matrix axes; any leading axes
  are carried through unchanged.

  Args:
    input_tensor: NumPy array holding the kernel in the source layout.
    target_shape: Sequence of ints giving the shape of the returned array.

  Returns:
    An array of shape `target_shape` with the matrix axes transposed.
  """
  if saving_to_hf:
    # Exact inverse of the load path below: that path transposes and then
    # reshapes, so here we must undo the reshape first and transpose last.
    # Swapping first would move the wrong axes whenever input_tensor has more
    # axes than target_shape (i.e. the reshape merges axes).
    pre_swap_shape = list(target_shape)
    pre_swap_shape[-1], pre_swap_shape[-2] = pre_swap_shape[-2], pre_swap_shape[-1]
    return np.swapaxes(input_tensor.reshape(pre_swap_shape), -1, -2)
  return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
11811177
11821178 num_main_layers = config ["num_hidden_layers" ]
11831179 first_num_dense_layers = config ["first_k_dense_replace" ]
@@ -1362,10 +1358,9 @@ def transpose(input_tensor, target_shape=None):
def reshape_kernel(input_tensor, target_shape):
  """Reshapes and transposes kernel weights between MaxText and HF.

  The direction is chosen by `saving_to_hf`, which is read from the enclosing
  scope. The last two axes are treated as the matrix axes; any leading axes
  are carried through unchanged.

  Args:
    input_tensor: NumPy array holding the kernel in the source layout.
    target_shape: Sequence of ints giving the shape of the returned array.

  Returns:
    An array of shape `target_shape` with the matrix axes transposed.
  """
  if saving_to_hf:
    # Exact inverse of the load path below: that path transposes and then
    # reshapes, so here we must undo the reshape first and transpose last.
    # Swapping first would move the wrong axes whenever input_tensor has more
    # axes than target_shape (i.e. the reshape merges axes).
    pre_swap_shape = list(target_shape)
    pre_swap_shape[-1], pre_swap_shape[-2] = pre_swap_shape[-2], pre_swap_shape[-1]
    return np.swapaxes(input_tensor.reshape(pre_swap_shape), -1, -2)
  return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
13691364
13701365 def reshape_bias (input_tensor , target_shape = None ):
13711366 """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
@@ -1971,10 +1966,9 @@ def adjust_rope(input_tensor, target_shape):
19711966
def reshape_kernel(input_tensor, target_shape):
  """Reshapes and transposes kernel weights between the two checkpoint layouts.

  The direction is chosen by `saving_to_hf`, which is read from the enclosing
  scope. The last two axes are treated as the matrix axes; any leading axes
  are carried through unchanged.

  Args:
    input_tensor: NumPy array holding the kernel in the source layout.
    target_shape: Sequence of ints giving the shape of the returned array.

  Returns:
    An array of shape `target_shape` with the matrix axes transposed.
  """
  if saving_to_hf:
    # Exact inverse of the load path below: that path transposes and then
    # reshapes, so here we must undo the reshape first and transpose last.
    # Swapping first would move the wrong axes whenever input_tensor has more
    # axes than target_shape (i.e. the reshape merges axes).
    pre_swap_shape = list(target_shape)
    pre_swap_shape[-1], pre_swap_shape[-2] = pre_swap_shape[-2], pre_swap_shape[-1]
    return np.swapaxes(input_tensor.reshape(pre_swap_shape), -1, -2)
  return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
19781972
19791973 # caveat: hook order does affect result
19801974 # to_huggingface
@@ -2549,10 +2543,9 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
25492543
def reshape_kernel(input_tensor, target_shape):
  """Reshapes and transposes kernel weights between the two checkpoint layouts.

  The direction is chosen by `saving_to_hf`, which is read from the enclosing
  scope. The last two axes are treated as the matrix axes; any leading axes
  are carried through unchanged.

  Args:
    input_tensor: NumPy array holding the kernel in the source layout.
    target_shape: Sequence of ints giving the shape of the returned array.

  Returns:
    An array of shape `target_shape` with the matrix axes transposed.
  """
  if saving_to_hf:
    # Exact inverse of the load path below: that path transposes and then
    # reshapes, so here we must undo the reshape first and transpose last.
    # Swapping first would move the wrong axes whenever input_tensor has more
    # axes than target_shape (i.e. the reshape merges axes).
    pre_swap_shape = list(target_shape)
    pre_swap_shape[-1], pre_swap_shape[-2] = pre_swap_shape[-2], pre_swap_shape[-1]
    return np.swapaxes(input_tensor.reshape(pre_swap_shape), -1, -2)
  return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
25562549
25572550 def scale_rmsnorm_layer (input_tensor , target_shape ):
25582551 # Shift of 1.0 is now folded into Gemma 4 text and vision checkpoint weights
@@ -2801,10 +2794,9 @@ def OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False,
28012794 # Standard Transpose for Kernels (HF: [Out, In] <-> MaxText: [In, Out])
def reshape_kernel(input_tensor, target_shape):
  """Reshapes and transposes kernel weights between the two checkpoint layouts.

  The direction is chosen by `saving_to_hf`, which is read from the enclosing
  scope. The last two axes are treated as the matrix axes; any leading axes
  are carried through unchanged.

  Args:
    input_tensor: NumPy array holding the kernel in the source layout.
    target_shape: Sequence of ints giving the shape of the returned array.

  Returns:
    An array of shape `target_shape` with the matrix axes transposed.
  """
  if saving_to_hf:
    # Exact inverse of the load path below: that path transposes and then
    # reshapes, so here we must undo the reshape first and transpose last.
    # Swapping first would move the wrong axes whenever input_tensor has more
    # axes than target_shape (i.e. the reshape merges axes).
    pre_swap_shape = list(target_shape)
    pre_swap_shape[-1], pre_swap_shape[-2] = pre_swap_shape[-2], pre_swap_shape[-1]
    return np.swapaxes(input_tensor.reshape(pre_swap_shape), -1, -2)
  return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
28082800
28092801 # Identity mapping for Norms
28102802 # Olmo3 checkpoints typically have weights ~1.0.
0 commit comments