Skip to content

Commit f4f21e3

Browse files
committed
fix: Checkpoint converter fixes for loading, merging, and recursively updating base and LoRA checkpoints
1 parent b747941 commit f4f21e3

4 files changed

Lines changed: 264 additions & 18 deletions

File tree

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,35 @@ def _get_lora_delta(key, lora_state_dict, lora_scaling):
116116
a_key, b_key = key[7:] + "_lora_a", key[7:] + "_lora_b"
117117

118118
if a_key in lora_state_dict and b_key in lora_state_dict:
119-
data_a, data_b = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32), jnp.asarray(
120-
lora_state_dict[b_key], dtype=jnp.float32
121-
)
122-
if data_a.ndim > 2:
119+
data_a = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32)
120+
data_b = jnp.asarray(lora_state_dict[b_key], dtype=jnp.float32)
121+
122+
is_attention = "attention" in key.lower() or "attn" in key.lower()
123+
124+
if is_attention and data_a.ndim > 2:
125+
if data_a.ndim == 4:
126+
# Scanned attention projection: [num_layers, input_dim, heads, rank] & [num_layers, rank, heads, output_dim]
127+
return jnp.einsum("lipr,lrpo->lipo", data_a, data_b) * lora_scaling
128+
# Unscanned attention projection: [input_dim, heads, rank] & [rank, heads, output_dim]
123129
return jnp.einsum("ipr,rpo->ipo", data_a, data_b) * lora_scaling
124-
return jnp.matmul(data_a, data_b) * lora_scaling
130+
else:
131+
if data_a.ndim == 3:
132+
# Scanned standard linear projection: can be [num_layers, input_dim, rank] or [input_dim, num_layers, rank]
133+
rank = data_a.shape[2]
134+
if rank == data_b.shape[1] and rank != data_b.shape[0]:
135+
# Case A: [num_layers, input_dim, rank] & [num_layers, rank, output_dim]
136+
return jnp.einsum("lir,lro->lio", data_a, data_b) * lora_scaling
137+
elif rank == data_b.shape[0] and rank != data_b.shape[1]:
138+
# Case B: [input_dim, num_layers, rank] & [rank, num_layers, output_dim]
139+
return jnp.einsum("ilr,rlo->ilo", data_a, data_b) * lora_scaling
140+
else:
141+
# Disambiguate using key names (Case B is typically 'wo' or 'out-kernel' / 'out_proj')
142+
if any(term in key for term in ["wo", "out-kernel", "out_proj"]):
143+
return jnp.einsum("ilr,rlo->ilo", data_a, data_b) * lora_scaling
144+
else:
145+
return jnp.einsum("lir,lro->lio", data_a, data_b) * lora_scaling
146+
# Unscanned standard linear projection
147+
return jnp.matmul(data_a, data_b) * lora_scaling
125148
return None
126149

127150

@@ -286,19 +309,38 @@ def _transform_weights_to_adapter(param_map, state_dict):
286309
if a_key in state_dict and b_key in state_dict:
287310
data_a, data_b = state_dict[a_key], state_dict[b_key]
288311
hf_paths = [hf_paths] if not isinstance(hf_paths, list) else hf_paths
289-
for i in range(min(data_a.shape[1] if data_a.ndim > 2 else 1, len(hf_paths))):
290-
found_hf_modules.add(hf_paths[i].split(".")[-2])
291-
name = hf_paths[i].replace(".weight", "")
312+
for i, hf_path in enumerate(hf_paths):
313+
found_hf_modules.add(hf_path.split(".")[-2])
314+
name = hf_path.replace(".weight", "")
315+
316+
if data_a.ndim > 2:
317+
if data_a.shape[0] == len(hf_paths):
318+
# Case A: layer dimension is axis 0
319+
layer_a = data_a[i, ...]
320+
layer_b = data_b[i, ...]
321+
else:
322+
# Case B: layer dimension is axis 1
323+
layer_a = data_a[:, i, ...]
324+
layer_b = data_b[:, i, ...]
325+
else:
326+
layer_a = data_a
327+
layer_b = data_b
328+
329+
if layer_a.ndim > 2:
330+
layer_a = layer_a[:, 0, :]
331+
if layer_b.ndim > 2:
332+
layer_b = layer_b[:, 0, :]
333+
292334
processed_params_list.append(
293335
(
294336
f"base_model.model.{name}.lora_A.weight",
295-
jax.numpy.asarray((data_a[:, i, :] if data_a.ndim > 2 else data_a).T),
337+
jax.numpy.asarray(layer_a.T),
296338
)
297339
)
298340
processed_params_list.append(
299341
(
300342
f"base_model.model.{name}.lora_B.weight",
301-
jax.numpy.asarray((data_b[:, i, :] if data_b.ndim > 2 else data_b).T),
343+
jax.numpy.asarray(layer_b.T),
302344
)
303345
)
304346
return dict(processed_params_list), found_hf_modules
@@ -424,9 +466,7 @@ def main(argv: Sequence[str]) -> None:
424466
maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict)
425467

426468
# Validate that checkpoint keys match the parameter mapping
427-
state_keys = set(maxtext_state_dict) | {
428-
k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict if "_lora_" in k
429-
}
469+
state_keys = {k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict}
430470
filtered_map_keys = validate_and_filter_param_map_keys(param_map, state_keys)
431471

432472
# When not converting a multimodal model, skip vision encoder weights even if

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,16 @@ def format_meter(
817817
return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
818818

819819

820+
def _recursive_update(d: dict, u: dict) -> dict:
821+
"""Recursively updates dictionary d with dictionary u in place."""
822+
for k, v in u.items():
823+
if isinstance(v, dict) and isinstance(d.get(k), dict):
824+
_recursive_update(d[k], v)
825+
else:
826+
d[k] = v
827+
return d
828+
829+
820830
def load_orbax_checkpoint(config) -> dict:
821831
"""Loads Orbax checkpoints from Base and/or LoRA paths in config.
822832
@@ -852,15 +862,21 @@ def create_restore_args(tree_metadata):
852862
paths = [p for p in [config.load_parameters_path, lora_path] if p]
853863

854864
merged_dict = {}
855-
for path in paths:
865+
for i, path in enumerate(paths):
856866
checkpoint_path = epath.Path(path)
857867
metadata = ckptr.metadata(checkpoint_path)
858868
restore_args = jax.tree_util.tree_map(
859869
lambda x: create_restore_args(x) if hasattr(x, "shape") else None,
860870
metadata.item_metadata.tree,
861871
is_leaf=lambda x: hasattr(x, "shape"),
862872
)
863-
merged_dict.update(ckptr.restore(checkpoint_path, restore_args=restore_args))
873+
restored = ckptr.restore(checkpoint_path, restore_args=restore_args)
874+
875+
if i == 0:
876+
merged_dict = restored
877+
else:
878+
# Recursively update base checkpoint with LoRA adapter checkpoint keys to avoid overwriting
879+
_recursive_update(merged_dict, restored)
864880

865881
return merged_dict
866882

@@ -903,7 +919,7 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
903919
result = {}
904920
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
905921
for path_tuple, leaf_value in leaves_with_paths:
906-
path_keys = [k.key for k in path_tuple]
922+
path_keys = [str(k.key) for k in path_tuple]
907923
# Skip NNX RNG state variables (not model weights)
908924
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
909925
continue

src/maxtext/configs/post_train/lora_module_path.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ llama3.1: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1
1919
qwen3: "decoder/layers/self_attention/(query|key|value|out)|decoder/layers/mlp/(wi_0|wi_1|wo)"
2020
mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2121
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
22-
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
23-
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
22+
gemma2: "decoder/(scanned_blocks|layers_remainder|layers)/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/(scanned_blocks|layers_remainder|layers)/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
23+
gemma3: "decoder/(scanned_blocks|layers_remainder|layers)/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
2424
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2525
gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))"
2626

tests/unit/hf_checkpoint_conversion_test.py

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
convert_hf_lora_key_to_maxtext,
2929
_process_and_stack_weights,
3030
)
31+
from maxtext.checkpoint_conversion.utils.utils import (
32+
_recursive_update,
33+
load_orbax_checkpoint,
34+
)
3135

3236

3337
class HFCheckpointConversionTest(unittest.TestCase):
@@ -105,6 +109,48 @@ def test_transform_weights_to_adapter(self):
105109
self.assertEqual(weights["base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight"].shape, (20, 4))
106110
self.assertIn("q_proj", modules)
107111

112+
# 1. Scanned standard linear Case A (3D): [num_layers, input_dim, rank] & [num_layers, rank, output_dim]
113+
param_map_scanned_a = {
114+
"params-decoder-scanned_blocks-mlp-wi_0-kernel": [
115+
"model.layers.0.mlp.gate_proj.weight",
116+
"model.layers.1.mlp.gate_proj.weight",
117+
]
118+
}
119+
# num_layers = 2, input_dim = 10, rank = 4, output_dim = 20
120+
data_a_scanned_a = np.ones((2, 10, 4), dtype=np.float32) * 0.5
121+
data_b_scanned_a = np.ones((2, 4, 20), dtype=np.float32) * 0.5
122+
lora_dict_scanned_a = {
123+
"params-decoder-scanned_blocks-mlp-wi_0-kernel_lora_a": data_a_scanned_a,
124+
"params-decoder-scanned_blocks-mlp-wi_0-kernel_lora_b": data_b_scanned_a,
125+
}
126+
weights_sa, _ = _transform_weights_to_adapter(param_map_scanned_a, lora_dict_scanned_a)
127+
self.assertIn("base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight", weights_sa)
128+
self.assertIn("base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight", weights_sa)
129+
# Since layer dimension is axis 0, layer 0 is data_a_scanned_a[0, :, :], which has shape (10, 4), transpose -> (4, 10)
130+
self.assertEqual(weights_sa["base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight"].shape, (4, 10))
131+
self.assertEqual(weights_sa["base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight"].shape, (20, 4))
132+
133+
# 2. Scanned standard linear Case B (3D): [input_dim, num_layers, rank] & [rank, num_layers, output_dim]
134+
param_map_scanned_b = {
135+
"params-decoder-scanned_blocks-mlp-wo-kernel": [
136+
"model.layers.0.mlp.down_proj.weight",
137+
"model.layers.1.mlp.down_proj.weight",
138+
]
139+
}
140+
# num_layers = 2, input_dim = 10, rank = 4, output_dim = 20
141+
data_a_scanned_b = np.ones((10, 2, 4), dtype=np.float32) * 0.5
142+
data_b_scanned_b = np.ones((4, 2, 20), dtype=np.float32) * 0.5
143+
lora_dict_scanned_b = {
144+
"params-decoder-scanned_blocks-mlp-wo-kernel_lora_a": data_a_scanned_b,
145+
"params-decoder-scanned_blocks-mlp-wo-kernel_lora_b": data_b_scanned_b,
146+
}
147+
weights_sb, _ = _transform_weights_to_adapter(param_map_scanned_b, lora_dict_scanned_b)
148+
self.assertIn("base_model.model.model.layers.0.mlp.down_proj.lora_A.weight", weights_sb)
149+
self.assertIn("base_model.model.model.layers.0.mlp.down_proj.lora_B.weight", weights_sb)
150+
# Since layer dimension is axis 1, layer 0 is data_a_scanned_b[:, 0, :], which has shape (10, 4), transpose -> (4, 10)
151+
self.assertEqual(weights_sb["base_model.model.model.layers.0.mlp.down_proj.lora_A.weight"].shape, (4, 10))
152+
self.assertEqual(weights_sb["base_model.model.model.layers.0.mlp.down_proj.lora_B.weight"].shape, (20, 4))
153+
108154
def test_transform_weights_to_full_model_merged(self):
109155
config = MagicMock()
110156
config.lora.lora_alpha = 32.0
@@ -125,6 +171,49 @@ def test_transform_weights_to_full_model_merged(self):
125171
self.assertIn("model.layers.0.self_attn.q_proj.weight", weights)
126172
self.assertTrue(np.allclose(weights["model.layers.0.self_attn.q_proj.weight"], self.expected_merged_val))
127173

174+
def test_get_lora_delta_scanned_and_unscanned_variants(self):
175+
cases = [
176+
# (name, key, shape_a, shape_b, expected_shape, expected_val)
177+
("2d_linear", "params-decoder-layers-layers_0-mlp-wi_0-kernel", (10, 4), (4, 20), (10, 20), 2.0),
178+
(
179+
"3d_unscanned_attn",
180+
"params-decoder-layers-layers_0-self_attention-query-kernel",
181+
(10, 2, 4),
182+
(4, 2, 20),
183+
(10, 2, 20),
184+
2.0,
185+
),
186+
(
187+
"3d_scanned_linear_a",
188+
"params-decoder-scanned_blocks-mlp-wi_0-kernel",
189+
(3, 10, 4),
190+
(3, 4, 20),
191+
(3, 10, 20),
192+
2.0,
193+
),
194+
("3d_scanned_linear_b", "params-decoder-scanned_blocks-mlp-wo-kernel", (10, 3, 4), (4, 3, 20), (10, 3, 20), 2.0),
195+
(
196+
"4d_scanned_attn",
197+
"params-decoder-scanned_blocks-self_attention-query-kernel",
198+
(3, 10, 2, 4),
199+
(3, 4, 2, 20),
200+
(3, 10, 2, 20),
201+
2.0,
202+
),
203+
("edge_case_a", "params-decoder-scanned_blocks-mlp-wi_0-kernel", (3, 3, 3), (3, 3, 20), (3, 3, 20), 1.5),
204+
("edge_case_b", "params-decoder-scanned_blocks-mlp-wo-kernel", (3, 3, 3), (3, 3, 20), (3, 3, 20), 1.5),
205+
]
206+
207+
for name, key, shape_a, shape_b, expected_shape, expected_val in cases:
208+
with self.subTest(name=name):
209+
state_dict = {
210+
f"{key}_lora_a": np.ones(shape_a, dtype=np.float32) * 0.5,
211+
f"{key}_lora_b": np.ones(shape_b, dtype=np.float32) * 0.5,
212+
}
213+
delta = _get_lora_delta(key, state_dict, 2.0)
214+
self.assertEqual(delta.shape, expected_shape)
215+
self.assertTrue(np.allclose(delta, expected_val))
216+
128217

129218
class HFToMaxTextLoRAConversionTest(unittest.TestCase):
130219
"""Tests the conversion logic in to_maxtext with LoRA support."""
@@ -169,5 +258,106 @@ def test_process_and_stack_weights(self):
169258
self.assertEqual(stacked[1, 0, 0], 2.0)
170259

171260

261+
class CheckpointMergingTest(unittest.TestCase):
262+
"""Tests the recursive_update and load_orbax_checkpoint functions to ensure we don't overwrite weights."""
263+
264+
def test_recursive_update(self):
265+
266+
base = {
267+
"params": {
268+
"decoder": {
269+
"layers": {
270+
"kernel": np.ones((4, 4)),
271+
}
272+
}
273+
}
274+
}
275+
lora = {
276+
"params": {
277+
"decoder": {
278+
"layers": {
279+
"kernel_lora_a": np.ones((4, 2)),
280+
"kernel_lora_b": np.ones((2, 4)),
281+
}
282+
}
283+
}
284+
}
285+
286+
merged = {}
287+
_recursive_update(merged, base)
288+
_recursive_update(merged, lora)
289+
290+
# Verify that both base and lora weights are present and not overwritten
291+
self.assertIn("kernel", merged["params"]["decoder"]["layers"])
292+
self.assertIn("kernel_lora_a", merged["params"]["decoder"]["layers"])
293+
self.assertIn("kernel_lora_b", merged["params"]["decoder"]["layers"])
294+
np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel"], np.ones((4, 4)))
295+
np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel_lora_a"], np.ones((4, 2)))
296+
np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel_lora_b"], np.ones((2, 4)))
297+
298+
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer")
299+
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path")
300+
@unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices")
301+
def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, mock_checkpointer_cls):
302+
303+
# Mock jax devices
304+
mock_jax_devices.return_value = [MagicMock()]
305+
306+
# Mock Orbax Checkpointer and its restore results
307+
mock_ckptr = MagicMock()
308+
mock_checkpointer_cls.return_value = mock_ckptr
309+
310+
# Base checkpoint metadata and content
311+
base_metadata = MagicMock()
312+
base_metadata.item_metadata.tree = {"params": {"decoder": {"layers": {"kernel": MagicMock(shape=(4, 4))}}}}
313+
base_restore_content = {"params": {"decoder": {"layers": {"kernel": np.ones((4, 4))}}}}
314+
315+
# LoRA checkpoint metadata and content
316+
lora_metadata = MagicMock()
317+
lora_metadata.item_metadata.tree = {
318+
"params": {
319+
"decoder": {
320+
"layers": {
321+
"kernel_lora_a": MagicMock(shape=(4, 2)),
322+
"kernel_lora_b": MagicMock(shape=(2, 4)),
323+
}
324+
}
325+
}
326+
}
327+
lora_restore_content = {
328+
"params": {
329+
"decoder": {
330+
"layers": {
331+
"kernel_lora_a": np.ones((4, 2)),
332+
"kernel_lora_b": np.ones((2, 4)),
333+
}
334+
}
335+
}
336+
}
337+
338+
# Mock metadata and restore calls
339+
mock_ckptr.metadata.side_effect = [base_metadata, lora_metadata]
340+
mock_ckptr.restore.side_effect = [base_restore_content, lora_restore_content]
341+
342+
# Create dummy config
343+
config = MagicMock()
344+
config.checkpoint_storage_concurrent_gb = 8
345+
config.checkpoint_storage_use_ocdbt = True
346+
config.checkpoint_storage_use_zarr3 = True
347+
config.load_parameters_path = "gs://base-bucket/checkpoints"
348+
config.lora.lora_restore_path = "gs://lora-bucket/checkpoints"
349+
350+
# Load and merge
351+
merged = load_orbax_checkpoint(config)
352+
353+
# Assert checkpointer was called twice and restored both
354+
self.assertEqual(mock_ckptr.restore.call_count, 2)
355+
356+
# Verify that the keys are recursively merged correctly!
357+
self.assertIn("kernel", merged["params"]["decoder"]["layers"])
358+
self.assertIn("kernel_lora_a", merged["params"]["decoder"]["layers"])
359+
self.assertIn("kernel_lora_b", merged["params"]["decoder"]["layers"])
360+
361+
172362
if __name__ == "__main__":
173363
unittest.main()

0 commit comments

Comments
 (0)