Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 53 additions & 13 deletions src/maxtext/checkpoint_conversion/to_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,35 @@ def _get_lora_delta(key, lora_state_dict, lora_scaling):
a_key, b_key = key[7:] + "_lora_a", key[7:] + "_lora_b"

if a_key in lora_state_dict and b_key in lora_state_dict:
data_a, data_b = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32), jnp.asarray(
lora_state_dict[b_key], dtype=jnp.float32
)
if data_a.ndim > 2:
data_a = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32)
data_b = jnp.asarray(lora_state_dict[b_key], dtype=jnp.float32)

is_attention = "attention" in key.lower() or "attn" in key.lower()

if is_attention and data_a.ndim > 2:
if data_a.ndim == 4:
# Scanned attention projection: [num_layers, input_dim, heads, rank] & [num_layers, rank, heads, output_dim]
return jnp.einsum("lipr,lrpo->lipo", data_a, data_b) * lora_scaling
# Unscanned attention projection: [input_dim, heads, rank] & [rank, heads, output_dim]
return jnp.einsum("ipr,rpo->ipo", data_a, data_b) * lora_scaling
return jnp.matmul(data_a, data_b) * lora_scaling
else:
if data_a.ndim == 3:
# Scanned standard linear projection: can be [num_layers, input_dim, rank] or [input_dim, num_layers, rank]
rank = data_a.shape[2]
if rank == data_b.shape[1] and rank != data_b.shape[0]:
# Case A: [num_layers, input_dim, rank] & [num_layers, rank, output_dim]
return jnp.einsum("lir,lro->lio", data_a, data_b) * lora_scaling
elif rank == data_b.shape[0] and rank != data_b.shape[1]:
# Case B: [input_dim, num_layers, rank] & [rank, num_layers, output_dim]
return jnp.einsum("ilr,rlo->ilo", data_a, data_b) * lora_scaling
else:
# Disambiguate using key names (Case B is typically 'wo' or 'out-kernel' / 'out_proj')
if any(term in key for term in ["wo", "out-kernel", "out_proj"]):
return jnp.einsum("ilr,rlo->ilo", data_a, data_b) * lora_scaling
else:
return jnp.einsum("lir,lro->lio", data_a, data_b) * lora_scaling
# Unscanned standard linear projection
return jnp.matmul(data_a, data_b) * lora_scaling
return None


Expand Down Expand Up @@ -286,19 +309,38 @@ def _transform_weights_to_adapter(param_map, state_dict):
if a_key in state_dict and b_key in state_dict:
data_a, data_b = state_dict[a_key], state_dict[b_key]
hf_paths = [hf_paths] if not isinstance(hf_paths, list) else hf_paths
for i in range(min(data_a.shape[1] if data_a.ndim > 2 else 1, len(hf_paths))):
found_hf_modules.add(hf_paths[i].split(".")[-2])
name = hf_paths[i].replace(".weight", "")
for i, hf_path in enumerate(hf_paths):
found_hf_modules.add(hf_path.split(".")[-2])
name = hf_path.replace(".weight", "")

if data_a.ndim > 2:
if data_a.shape[0] == len(hf_paths):
# Case A: layer dimension is axis 0
layer_a = data_a[i, ...]
layer_b = data_b[i, ...]
else:
# Case B: layer dimension is axis 1
layer_a = data_a[:, i, ...]
layer_b = data_b[:, i, ...]
else:
layer_a = data_a
layer_b = data_b

if layer_a.ndim > 2:
layer_a = layer_a[:, 0, :]
if layer_b.ndim > 2:
layer_b = layer_b[:, 0, :]

processed_params_list.append(
(
f"base_model.model.{name}.lora_A.weight",
jax.numpy.asarray((data_a[:, i, :] if data_a.ndim > 2 else data_a).T),
jax.numpy.asarray(layer_a.T),
)
)
processed_params_list.append(
(
f"base_model.model.{name}.lora_B.weight",
jax.numpy.asarray((data_b[:, i, :] if data_b.ndim > 2 else data_b).T),
jax.numpy.asarray(layer_b.T),
)
)
return dict(processed_params_list), found_hf_modules
Expand Down Expand Up @@ -424,9 +466,7 @@ def main(argv: Sequence[str]) -> None:
maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict)

# Validate that checkpoint keys match the parameter mapping
state_keys = set(maxtext_state_dict) | {
k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict if "_lora_" in k
}
state_keys = {k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict}
filtered_map_keys = validate_and_filter_param_map_keys(param_map, state_keys)

# When not converting a multimodal model, skip vision encoder weights even if
Expand Down
22 changes: 19 additions & 3 deletions src/maxtext/checkpoint_conversion/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,16 @@ def format_meter(
return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs)


def _recursive_update(d: dict, u: dict) -> dict:
"""Recursively updates dictionary d with dictionary u in place."""
for k, v in u.items():
if isinstance(v, dict) and isinstance(d.get(k), dict):
_recursive_update(d[k], v)
else:
d[k] = v
return d


def load_orbax_checkpoint(config) -> dict:
"""Loads Orbax checkpoints from Base and/or LoRA paths in config.

Expand Down Expand Up @@ -852,15 +862,21 @@ def create_restore_args(tree_metadata):
paths = [p for p in [config.load_parameters_path, lora_path] if p]

merged_dict = {}
for path in paths:
for i, path in enumerate(paths):
checkpoint_path = epath.Path(path)
metadata = ckptr.metadata(checkpoint_path)
restore_args = jax.tree_util.tree_map(
lambda x: create_restore_args(x) if hasattr(x, "shape") else None,
metadata.item_metadata.tree,
is_leaf=lambda x: hasattr(x, "shape"),
)
merged_dict.update(ckptr.restore(checkpoint_path, restore_args=restore_args))
restored = ckptr.restore(checkpoint_path, restore_args=restore_args)

if i == 0:
merged_dict = restored
else:
# Recursively update base checkpoint with LoRA adapter checkpoint keys to avoid overwriting
_recursive_update(merged_dict, restored)

return merged_dict

Expand Down Expand Up @@ -903,7 +919,7 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]:
result = {}
leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict)
for path_tuple, leaf_value in leaves_with_paths:
path_keys = [k.key for k in path_tuple]
path_keys = [str(k.key) for k in path_tuple]
# Skip NNX RNG state variables (not model weights)
if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys):
continue
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/configs/post_train/lora_module_path.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ llama3.1: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1
qwen3: "decoder/layers/self_attention/(query|key|value|out)|decoder/layers/mlp/(wi_0|wi_1|wo)"
mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)"
gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
gemma2: "decoder/(scanned_blocks|layers_remainder|layers)/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/(scanned_blocks|layers_remainder|layers)/(mlp_local|mlp_global)/(wi_0|wi_1|wo)"
gemma3: "decoder/(scanned_blocks|layers_remainder|layers)/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))"
olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))"
gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))"

Expand Down
Loading
Loading