Skip to content

Commit ad7fe65

Browse files
committed
fix: Checkpoint converter fixes for loading, merging, and recursively updating base and LoRA checkpoints
1 parent b747941 commit ad7fe65

4 files changed

Lines changed: 337 additions & 18 deletions

File tree

src/maxtext/checkpoint_conversion/to_huggingface.py

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,35 @@ def _get_lora_delta(key, lora_state_dict, lora_scaling):
116116
a_key, b_key = key[7:] + "_lora_a", key[7:] + "_lora_b"
117117

118118
if a_key in lora_state_dict and b_key in lora_state_dict:
119-
data_a, data_b = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32), jnp.asarray(
120-
lora_state_dict[b_key], dtype=jnp.float32
121-
)
122-
if data_a.ndim > 2:
119+
data_a = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32)
120+
data_b = jnp.asarray(lora_state_dict[b_key], dtype=jnp.float32)
121+
122+
is_attention = "attention" in key.lower() or "attn" in key.lower()
123+
124+
if is_attention and data_a.ndim > 2:
125+
if data_a.ndim == 4:
126+
# Scanned attention projection: [num_layers, input_dim, heads, rank] & [num_layers, rank, heads, output_dim]
127+
return jnp.einsum("lipr,lrpo->lipo", data_a, data_b) * lora_scaling
128+
# Unscanned attention projection: [input_dim, heads, rank] & [rank, heads, output_dim]
123129
return jnp.einsum("ipr,rpo->ipo", data_a, data_b) * lora_scaling
124-
return jnp.matmul(data_a, data_b) * lora_scaling
130+
else:
131+
if data_a.ndim == 3:
132+
# Scanned standard linear projection: can be [num_layers, input_dim, rank] or [input_dim, num_layers, rank]
133+
rank = data_a.shape[2]
134+
if rank == data_b.shape[1] and rank != data_b.shape[0]:
135+
# Case A: [num_layers, input_dim, rank] & [num_layers, rank, output_dim]
136+
return jnp.einsum("lir,lro->lio", data_a, data_b) * lora_scaling
137+
elif rank == data_b.shape[0] and rank != data_b.shape[1]:
138+
# Case B: [input_dim, num_layers, rank] & [rank, num_layers, output_dim]
139+
return jnp.einsum("ilr,rlo->ilo", data_a, data_b) * lora_scaling
140+
else:
141+
# Disambiguate using key names (Case B is typically 'wo' or 'out-kernel' / 'out_proj')
142+
if any(term in key for term in ["wo", "out-kernel", "out_proj"]):
143+
return jnp.einsum("ilr,rlo->ilo", data_a, data_b) * lora_scaling
144+
else:
145+
return jnp.einsum("lir,lro->lio", data_a, data_b) * lora_scaling
146+
# Unscanned standard linear projection
147+
return jnp.matmul(data_a, data_b) * lora_scaling
125148
return None
126149

127150

@@ -286,19 +309,38 @@ def _transform_weights_to_adapter(param_map, state_dict):
286309
if a_key in state_dict and b_key in state_dict:
287310
data_a, data_b = state_dict[a_key], state_dict[b_key]
288311
hf_paths = [hf_paths] if not isinstance(hf_paths, list) else hf_paths
289-
for i in range(min(data_a.shape[1] if data_a.ndim > 2 else 1, len(hf_paths))):
290-
found_hf_modules.add(hf_paths[i].split(".")[-2])
291-
name = hf_paths[i].replace(".weight", "")
312+
for i, hf_path in enumerate(hf_paths):
313+
found_hf_modules.add(hf_path.split(".")[-2])
314+
name = hf_path.replace(".weight", "")
315+
316+
if data_a.ndim > 2:
317+
if data_a.shape[0] == len(hf_paths):
318+
# Case A: layer dimension is axis 0
319+
layer_a = data_a[i, ...]
320+
layer_b = data_b[i, ...]
321+
else:
322+
# Case B: layer dimension is axis 1
323+
layer_a = data_a[:, i, ...]
324+
layer_b = data_b[:, i, ...]
325+
else:
326+
layer_a = data_a
327+
layer_b = data_b
328+
329+
if layer_a.ndim > 2:
330+
layer_a = layer_a[:, 0, :]
331+
if layer_b.ndim > 2:
332+
layer_b = layer_b[:, 0, :]
333+
292334
processed_params_list.append(
293335
(
294336
f"base_model.model.{name}.lora_A.weight",
295-
jax.numpy.asarray((data_a[:, i, :] if data_a.ndim > 2 else data_a).T),
337+
jax.numpy.asarray(layer_a.T),
296338
)
297339
)
298340
processed_params_list.append(
299341
(
300342
f"base_model.model.{name}.lora_B.weight",
301-
jax.numpy.asarray((data_b[:, i, :] if data_b.ndim > 2 else data_b).T),
343+
jax.numpy.asarray(layer_b.T),
302344
)
303345
)
304346
return dict(processed_params_list), found_hf_modules
@@ -424,9 +466,7 @@ def main(argv: Sequence[str]) -> None:
424466
maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict)
425467

426468
# Validate that checkpoint keys match the parameter mapping
427-
state_keys = set(maxtext_state_dict) | {
428-
k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict if "_lora_" in k
429-
}
469+
state_keys = {k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict}
430470
filtered_map_keys = validate_and_filter_param_map_keys(param_map, state_keys)
431471

432472
# When not converting a multimodal model, skip vision encoder weights even if

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,16 @@ def format_meter(
817817
return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)
818818

819819

820+
def _recursive_update(d: dict, u: dict) -> dict:
821+
"""Recursively updates dictionary d with dictionary u in place."""
822+
for k, v in u.items():
823+
if isinstance(v, dict) and isinstance(d.get(k), dict):
824+
_recursive_update(d[k], v)
825+
else:
826+
d[k] = v
827+
return d
828+
829+
820830
def load_orbax_checkpoint(config) -> dict:
821831
"""Loads Orbax checkpoints from Base and/or LoRA paths in config.
822832
@@ -852,15 +862,21 @@ def create_restore_args(tree_metadata):
852862
paths = [p for p in [config.load_parameters_path, lora_path] if p]
853863

854864
merged_dict = {}
855-
for path in paths:
865+
for i, path in enumerate(paths):
856866
checkpoint_path = epath.Path(path)
857867
metadata = ckptr.metadata(checkpoint_path)
858868
restore_args = jax.tree_util.tree_map(
859869
lambda x: create_restore_args(x) if hasattr(x, "shape") else None,
860870
metadata.item_metadata.tree,
861871
is_leaf=lambda x: hasattr(x, "shape"),
862872
)
863-
merged_dict.update(ckptr.restore(checkpoint_path, restore_args=restore_args))
873+
restored = ckptr.restore(checkpoint_path, restore_args=restore_args)
874+
875+
if i == 0:
876+
merged_dict = restored
877+
else:
878+
# Recursively update base checkpoint with LoRA adapter checkpoint keys to avoid overwriting
879+
_recursive_update(merged_dict, restored)
864880

865881
return merged_dict
866882

@@ -903,7 +919,7 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
903919
result = {}
904920
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
905921
for path_tuple, leaf_value in leaves_with_paths:
906-
path_keys = [k.key for k in path_tuple]
922+
path_keys = [str(k.key) for k in path_tuple]
907923
# Skip NNX RNG state variables (not model weights)
908924
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
909925
continue

src/maxtext/configs/post_train/lora_module_path.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ llama3.1: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1
1919
qwen3: "decoder/layers/self_attention/(query|key|value|out)|decoder/layers/mlp/(wi_0|wi_1|wo)"
2020
mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2121
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
22-
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
23-
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
22+
gemma2: "decoder/(scanned_blocks|layers_remainder|layers)/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/(scanned_blocks|layers_remainder|layers)/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
23+
gemma3: "decoder/(scanned_blocks|layers_remainder|layers)/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
2424
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
2525
gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))"
2626

0 commit comments

Comments
 (0)