
Commit 41267f2

fix: Fix memory scrambling bug in reshape_kernel hooks for scanned layers and MoE experts
1 parent: c84d94c

1 file changed

src/maxtext/checkpoint_conversion/utils/param_mapping.py

Lines changed: 16 additions & 24 deletions
@@ -505,10 +505,9 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
 
   def reshape_kernel(input_tensor, target_shape):
     if saving_to_hf:
-      flipped_target_shape = np.flip(np.array(target_shape))
-      return input_tensor.reshape(flipped_target_shape).T
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
     else:
-      return input_tensor.T.reshape(target_shape)
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
 
   def scale_rmsnorm_layer(input_tensor, target_shape):
     if saving_to_hf:
@@ -773,10 +772,9 @@ def pad_embedding_layer(input_tensor, target_shape):
   def reshape_kernel(input_tensor, target_shape):
     """Reshapes and transposes kernel weights between MaxText and HF."""
     if saving_to_hf:
-      flipped_target_shape = np.flip(np.array(target_shape))
-      return input_tensor.reshape(flipped_target_shape).T
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
     else:
-      return input_tensor.T.reshape(target_shape)
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
 
   def reshape_bias(input_tensor, target_shape=None):
     """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
@@ -1019,10 +1017,9 @@ def transpose(input_tensor, target_shape=None):
 
   def reshape_kernel(input_tensor, target_shape):
     if saving_to_hf:
-      flipped_target_shape = np.flip(np.array(target_shape))
-      return input_tensor.reshape(flipped_target_shape).T
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
     else:
-      return input_tensor.T.reshape(target_shape)
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
 
   def permute_conv(input_tensor, target_shape=None):
     # MT: [K, 1, C] <-> HF: [C, 1, K]
@@ -1174,10 +1171,9 @@ def DEEPSEEK_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=Fal
   def reshape_kernel(input_tensor, target_shape):
     """Reshapes and transposes kernel weights between MaxText and HF."""
     if saving_to_hf:
-      flipped_target_shape = np.flip(np.array(target_shape))
-      return input_tensor.reshape(flipped_target_shape).T
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
     else:
-      return input_tensor.T.reshape(target_shape)
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
 
   num_main_layers = config["num_hidden_layers"]
   first_num_dense_layers = config["first_k_dense_replace"]
@@ -1362,10 +1358,9 @@ def transpose(input_tensor, target_shape=None):
   def reshape_kernel(input_tensor, target_shape):
     """Reshapes and transposes kernel weights between MaxText and HF."""
     if saving_to_hf:
-      flipped_target_shape = np.flip(np.array(target_shape))
-      return input_tensor.reshape(flipped_target_shape).T
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
     else:
-      return input_tensor.T.reshape(target_shape)
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
 
   def reshape_bias(input_tensor, target_shape=None):
     """Reshapes biases between MaxText 2D (heads, dim) and HF 1D (hidden)."""
@@ -1971,10 +1966,9 @@ def adjust_rope(input_tensor, target_shape):
 
   def reshape_kernel(input_tensor, target_shape):
     if saving_to_hf:
-      flipped_target_shape = np.flip(np.array(target_shape))
-      return input_tensor.reshape(flipped_target_shape).transpose()
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
     else:
-      return input_tensor.transpose().reshape(target_shape)
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
 
   # caveat: hook order does affect result
   # to_huggingface
@@ -2549,10 +2543,9 @@ def pad_hf_embedding_layer(input_tensor, target_shape):
 
   def reshape_kernel(input_tensor, target_shape):
     if saving_to_hf:
-      flipped_target_shape = np.flip(np.array(target_shape))
-      return input_tensor.reshape(flipped_target_shape).T
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
     else:
-      return input_tensor.T.reshape(target_shape)
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
 
   def scale_rmsnorm_layer(input_tensor, target_shape):
     # Shift of 1.0 is now folded into Gemma 4 text and vision checkpoint weights
@@ -2801,10 +2794,9 @@ def OLMO3_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False,
   # Standard Transpose for Kernels (HF: [Out, In] <-> MaxText: [In, Out])
   def reshape_kernel(input_tensor, target_shape):
     if saving_to_hf:
-      flipped_target_shape = np.flip(np.array(target_shape))
-      return input_tensor.reshape(flipped_target_shape).T
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
     else:
-      return input_tensor.T.reshape(target_shape)
+      return np.swapaxes(input_tensor, -1, -2).reshape(target_shape)
 
   # Identity mapping for Norms
   # Olmo3 checkpoints typically have weights ~1.0.
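
Why the old hooks scrambled memory: ndarray.T reverses every axis, not just the last two. For the 3-D kernels used by scanned layers and MoE experts, the old reshape-then-transpose path therefore produced the right target shape but permuted the underlying values, while swapping only the last two axes preserves the leading stack axis. A minimal NumPy sketch of the difference (not code from this repo; the expert-stacked shapes and the per-expert [Out, In] HF layout below are illustrative assumptions):

import numpy as np

# Hypothetical MoE kernel: E experts stacked on the leading axis,
# each expert a 2-D [In, Out] kernel.
E, In, Out = 4, 3, 5
mt_kernel = np.arange(E * In * Out).reshape(E, In, Out)
hf_shape = (E, Out, In)  # assumed per-expert [Out, In] on the HF side

# Old hook (saving_to_hf branch): reshape to the flipped target shape,
# then transpose. .T reverses ALL axes, so the result has the right
# shape but mixes values across experts.
flipped = np.flip(np.array(hf_shape))  # (In, Out, E)
old = mt_kernel.reshape(flipped).T     # shape (4, 5, 3), data scrambled

# New hook: swap only the last two axes, then reshape (a no-op here).
new = np.swapaxes(mt_kernel, -1, -2).reshape(hf_shape)

# Per expert, the new result is exactly that expert's transposed kernel;
# the old result is not.
assert np.array_equal(new[0], mt_kernel[0].T)
assert not np.array_equal(old, new)

# The new hook also round-trips: applying it in the from-HF direction
# recovers the original MaxText kernel bit for bit.
roundtrip = np.swapaxes(new, -1, -2).reshape(mt_kernel.shape)
assert np.array_equal(roundtrip, mt_kernel)

# For a plain 2-D kernel whose HF shape is simply its transpose, the old
# and new paths agree, which is consistent with only scanned layers and
# MoE experts being affected.
k2 = np.arange(In * Out).reshape(In, Out)
old_2d = k2.reshape(np.flip(np.array((Out, In)))).T
new_2d = np.swapaxes(k2, -1, -2).reshape((Out, In))
assert np.array_equal(old_2d, new_2d)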
