From 7c0b0f62edcf12bacb7eb3074c207a460cc4d604 Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Tue, 2 Jun 2026 08:43:34 +0000 Subject: [PATCH] fix: Checkpoint converter fixes for loading, merging, and recursively updating base and LoRA checkpoints --- .../checkpoint_conversion/to_huggingface.py | 66 ++++- .../checkpoint_conversion/utils/utils.py | 22 +- .../configs/post_train/lora_module_path.yml | 4 +- tests/unit/hf_checkpoint_conversion_test.py | 263 ++++++++++++++++++ 4 files changed, 337 insertions(+), 18 deletions(-) diff --git a/src/maxtext/checkpoint_conversion/to_huggingface.py b/src/maxtext/checkpoint_conversion/to_huggingface.py index 0c498d7c3a..03f203d0c3 100644 --- a/src/maxtext/checkpoint_conversion/to_huggingface.py +++ b/src/maxtext/checkpoint_conversion/to_huggingface.py @@ -116,12 +116,35 @@ def _get_lora_delta(key, lora_state_dict, lora_scaling): a_key, b_key = key[7:] + "_lora_a", key[7:] + "_lora_b" if a_key in lora_state_dict and b_key in lora_state_dict: - data_a, data_b = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32), jnp.asarray( - lora_state_dict[b_key], dtype=jnp.float32 - ) - if data_a.ndim > 2: + data_a = jnp.asarray(lora_state_dict[a_key], dtype=jnp.float32) + data_b = jnp.asarray(lora_state_dict[b_key], dtype=jnp.float32) + + is_attention = "attention" in key.lower() or "attn" in key.lower() + + if is_attention and data_a.ndim > 2: + if data_a.ndim == 4: + # Scanned attention projection: [num_layers, input_dim, heads, rank] & [num_layers, rank, heads, output_dim] + return jnp.einsum("lipr,lrpo->lipo", data_a, data_b) * lora_scaling + # Unscanned attention projection: [input_dim, heads, rank] & [rank, heads, output_dim] return jnp.einsum("ipr,rpo->ipo", data_a, data_b) * lora_scaling - return jnp.matmul(data_a, data_b) * lora_scaling + else: + if data_a.ndim == 3: + # Scanned standard linear projection: can be [num_layers, input_dim, rank] or [input_dim, num_layers, rank] + rank = data_a.shape[2] + if rank == data_b.shape[1] and rank != data_b.shape[0]: + # Case A: [num_layers, input_dim, rank] & [num_layers, rank, output_dim] + return jnp.einsum("lir,lro->lio", data_a, data_b) * lora_scaling + elif rank == data_b.shape[0] and rank != data_b.shape[1]: + # Case B: [input_dim, num_layers, rank] & [rank, num_layers, output_dim] + return jnp.einsum("ilr,rlo->ilo", data_a, data_b) * lora_scaling + else: + # Disambiguate using key names (Case B is typically 'wo' or 'out-kernel' / 'out_proj') + if any(term in key for term in ["wo", "out-kernel", "out_proj"]): + return jnp.einsum("ilr,rlo->ilo", data_a, data_b) * lora_scaling + else: + return jnp.einsum("lir,lro->lio", data_a, data_b) * lora_scaling + # Unscanned standard linear projection + return jnp.matmul(data_a, data_b) * lora_scaling return None @@ -286,19 +309,38 @@ def _transform_weights_to_adapter(param_map, state_dict): if a_key in state_dict and b_key in state_dict: data_a, data_b = state_dict[a_key], state_dict[b_key] hf_paths = [hf_paths] if not isinstance(hf_paths, list) else hf_paths - for i in range(min(data_a.shape[1] if data_a.ndim > 2 else 1, len(hf_paths))): - found_hf_modules.add(hf_paths[i].split(".")[-2]) - name = hf_paths[i].replace(".weight", "") + for i, hf_path in enumerate(hf_paths): + found_hf_modules.add(hf_path.split(".")[-2]) + name = hf_path.replace(".weight", "") + + if data_a.ndim > 2: + if data_a.shape[0] == len(hf_paths): + # Case A: layer dimension is axis 0 + layer_a = data_a[i, ...] + layer_b = data_b[i, ...] + else: + # Case B: layer dimension is axis 1 + layer_a = data_a[:, i, ...] + layer_b = data_b[:, i, ...] + else: + layer_a = data_a + layer_b = data_b + + if layer_a.ndim > 2: + layer_a = layer_a[:, 0, :] + if layer_b.ndim > 2: + layer_b = layer_b[:, 0, :] + processed_params_list.append( ( f"base_model.model.{name}.lora_A.weight", - jax.numpy.asarray((data_a[:, i, :] if data_a.ndim > 2 else data_a).T), + jax.numpy.asarray(layer_a.T), ) ) processed_params_list.append( ( f"base_model.model.{name}.lora_B.weight", - jax.numpy.asarray((data_b[:, i, :] if data_b.ndim > 2 else data_b).T), + jax.numpy.asarray(layer_b.T), ) ) return dict(processed_params_list), found_hf_modules @@ -424,9 +466,7 @@ def main(argv: Sequence[str]) -> None: maxtext_state_dict = detect_and_extract_checkpoint(checkpoint_dict) # Validate that checkpoint keys match the parameter mapping - state_keys = set(maxtext_state_dict) | { - k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict if "_lora_" in k - } + state_keys = {k.replace("_lora_a", "").replace("_lora_b", "") for k in maxtext_state_dict} filtered_map_keys = validate_and_filter_param_map_keys(param_map, state_keys) # When not converting a multimodal model, skip vision encoder weights even if diff --git a/src/maxtext/checkpoint_conversion/utils/utils.py b/src/maxtext/checkpoint_conversion/utils/utils.py index 62eead2eb6..c60c280837 100644 --- a/src/maxtext/checkpoint_conversion/utils/utils.py +++ b/src/maxtext/checkpoint_conversion/utils/utils.py @@ -817,6 +817,16 @@ def format_meter( return super().format_meter(n=n, total=total, elapsed=elapsed, postfix=postfix, **extra_kwargs) +def _recursive_update(d: dict, u: dict) -> dict: + """Recursively updates dictionary d with dictionary u in place.""" + for k, v in u.items(): + if isinstance(v, dict) and isinstance(d.get(k), dict): + _recursive_update(d[k], v) + else: + d[k] = v + return d + + def load_orbax_checkpoint(config) -> dict: """Loads Orbax checkpoints from Base and/or LoRA paths in config. @@ -852,7 +862,7 @@ def create_restore_args(tree_metadata): paths = [p for p in [config.load_parameters_path, lora_path] if p] merged_dict = {} - for path in paths: + for i, path in enumerate(paths): checkpoint_path = epath.Path(path) metadata = ckptr.metadata(checkpoint_path) restore_args = jax.tree_util.tree_map( @@ -860,7 +870,13 @@ def create_restore_args(tree_metadata): metadata.item_metadata.tree, is_leaf=lambda x: hasattr(x, "shape"), ) - merged_dict.update(ckptr.restore(checkpoint_path, restore_args=restore_args)) + restored = ckptr.restore(checkpoint_path, restore_args=restore_args) + + if i == 0: + merged_dict = restored + else: + # Recursively update base checkpoint with LoRA adapter checkpoint keys to avoid overwriting + _recursive_update(merged_dict, restored) return merged_dict @@ -903,7 +919,7 @@ def extract_nnx_weights(weights_dict: dict) -> dict[str, np.ndarray]: result = {} leaves_with_paths = jax.tree_util.tree_leaves_with_path(weights_dict) for path_tuple, leaf_value in leaves_with_paths: - path_keys = [k.key for k in path_tuple] + path_keys = [str(k.key) for k in path_tuple] # Skip NNX RNG state variables (not model weights) if "to_nnx__rngs" in path_keys or any(k.endswith("_rngs") for k in path_keys): continue diff --git a/src/maxtext/configs/post_train/lora_module_path.yml b/src/maxtext/configs/post_train/lora_module_path.yml index 11f81d52c5..11d2a31b57 100644 --- a/src/maxtext/configs/post_train/lora_module_path.yml +++ b/src/maxtext/configs/post_train/lora_module_path.yml @@ -19,8 +19,8 @@ llama3.1: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1 qwen3: "decoder/layers/self_attention/(query|key|value|out)|decoder/layers/mlp/(wi_0|wi_1|wo)" mistral: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" deepseek2: "decoder/(dense_layers|moe_stack)/self_attention/(query|out|wkv_a|wkv_b)|decoder/(dense_layers|moe_stack)/(mlp|shared_experts)/(wi_0|wi_1|wo)" -gemma2: "decoder/layers/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/layers/(mlp_local|mlp_global)/(wi_0|wi_1|wo)" -gemma3: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))" +gemma2: "decoder/(scanned_blocks|layers_remainder|layers)/(self_attention_local|self_attention_global)/(query|key|value|out)|decoder/(scanned_blocks|layers_remainder|layers)/(mlp_local|mlp_global)/(wi_0|wi_1|wo)" +gemma3: "decoder/(scanned_blocks|layers_remainder|layers)/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo|gate|up|down))" olmo3: "decoder/layers/.*(attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" gpt3: "decoder/layers/(self_attention/(qkv_proj|out)|mlp/(wi|wo))" diff --git a/tests/unit/hf_checkpoint_conversion_test.py b/tests/unit/hf_checkpoint_conversion_test.py index 02ed7a5598..75fefb150c 100644 --- a/tests/unit/hf_checkpoint_conversion_test.py +++ b/tests/unit/hf_checkpoint_conversion_test.py @@ -28,6 +28,10 @@ convert_hf_lora_key_to_maxtext, _process_and_stack_weights, ) +from maxtext.checkpoint_conversion.utils.utils import ( + _recursive_update, + load_orbax_checkpoint, +) class HFCheckpointConversionTest(unittest.TestCase): @@ -105,6 +109,48 @@ def test_transform_weights_to_adapter(self): self.assertEqual(weights["base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight"].shape, (20, 4)) self.assertIn("q_proj", modules) + # 1. Scanned standard linear Case A (3D): [num_layers, input_dim, rank] & [num_layers, rank, output_dim] + param_map_scanned_a = { + "params-decoder-scanned_blocks-mlp-wi_0-kernel": [ + "model.layers.0.mlp.gate_proj.weight", + "model.layers.1.mlp.gate_proj.weight", + ] + } + # num_layers = 2, input_dim = 10, rank = 4, output_dim = 20 + data_a_scanned_a = np.ones((2, 10, 4), dtype=np.float32) * 0.5 + data_b_scanned_a = np.ones((2, 4, 20), dtype=np.float32) * 0.5 + lora_dict_scanned_a = { + "params-decoder-scanned_blocks-mlp-wi_0-kernel_lora_a": data_a_scanned_a, + "params-decoder-scanned_blocks-mlp-wi_0-kernel_lora_b": data_b_scanned_a, + } + weights_sa, _ = _transform_weights_to_adapter(param_map_scanned_a, lora_dict_scanned_a) + self.assertIn("base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight", weights_sa) + self.assertIn("base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight", weights_sa) + # Since layer dimension is axis 0, layer 0 is data_a_scanned_a[0, :, :], which has shape (10, 4), transpose -> (4, 10) + self.assertEqual(weights_sa["base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight"].shape, (4, 10)) + self.assertEqual(weights_sa["base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight"].shape, (20, 4)) + + # 2. Scanned standard linear Case B (3D): [input_dim, num_layers, rank] & [rank, num_layers, output_dim] + param_map_scanned_b = { + "params-decoder-scanned_blocks-mlp-wo-kernel": [ + "model.layers.0.mlp.down_proj.weight", + "model.layers.1.mlp.down_proj.weight", + ] + } + # num_layers = 2, input_dim = 10, rank = 4, output_dim = 20 + data_a_scanned_b = np.ones((10, 2, 4), dtype=np.float32) * 0.5 + data_b_scanned_b = np.ones((4, 2, 20), dtype=np.float32) * 0.5 + lora_dict_scanned_b = { + "params-decoder-scanned_blocks-mlp-wo-kernel_lora_a": data_a_scanned_b, + "params-decoder-scanned_blocks-mlp-wo-kernel_lora_b": data_b_scanned_b, + } + weights_sb, _ = _transform_weights_to_adapter(param_map_scanned_b, lora_dict_scanned_b) + self.assertIn("base_model.model.model.layers.0.mlp.down_proj.lora_A.weight", weights_sb) + self.assertIn("base_model.model.model.layers.0.mlp.down_proj.lora_B.weight", weights_sb) + # Since layer dimension is axis 1, layer 0 is data_a_scanned_b[:, 0, :], which has shape (10, 4), transpose -> (4, 10) + self.assertEqual(weights_sb["base_model.model.model.layers.0.mlp.down_proj.lora_A.weight"].shape, (4, 10)) + self.assertEqual(weights_sb["base_model.model.model.layers.0.mlp.down_proj.lora_B.weight"].shape, (20, 4)) + def test_transform_weights_to_full_model_merged(self): config = MagicMock() config.lora.lora_alpha = 32.0 @@ -125,6 +171,49 @@ def test_transform_weights_to_full_model_merged(self): self.assertIn("model.layers.0.self_attn.q_proj.weight", weights) self.assertTrue(np.allclose(weights["model.layers.0.self_attn.q_proj.weight"], self.expected_merged_val)) + def test_get_lora_delta_scanned_and_unscanned_variants(self): + cases = [ + # (name, key, shape_a, shape_b, expected_shape, expected_val) + ("2d_linear", "params-decoder-layers-layers_0-mlp-wi_0-kernel", (10, 4), (4, 20), (10, 20), 2.0), + ( + "3d_unscanned_attn", + "params-decoder-layers-layers_0-self_attention-query-kernel", + (10, 2, 4), + (4, 2, 20), + (10, 2, 20), + 2.0, + ), + ( + "3d_scanned_linear_a", + "params-decoder-scanned_blocks-mlp-wi_0-kernel", + (3, 10, 4), + (3, 4, 20), + (3, 10, 20), + 2.0, + ), + ("3d_scanned_linear_b", "params-decoder-scanned_blocks-mlp-wo-kernel", (10, 3, 4), (4, 3, 20), (10, 3, 20), 2.0), + ( + "4d_scanned_attn", + "params-decoder-scanned_blocks-self_attention-query-kernel", + (3, 10, 2, 4), + (3, 4, 2, 20), + (3, 10, 2, 20), + 2.0, + ), + ("edge_case_a", "params-decoder-scanned_blocks-mlp-wi_0-kernel", (3, 3, 3), (3, 3, 20), (3, 3, 20), 1.5), + ("edge_case_b", "params-decoder-scanned_blocks-mlp-wo-kernel", (3, 3, 3), (3, 3, 20), (3, 3, 20), 1.5), + ] + + for name, key, shape_a, shape_b, expected_shape, expected_val in cases: + with self.subTest(name=name): + state_dict = { + f"{key}_lora_a": np.ones(shape_a, dtype=np.float32) * 0.5, + f"{key}_lora_b": np.ones(shape_b, dtype=np.float32) * 0.5, + } + delta = _get_lora_delta(key, state_dict, 2.0) + self.assertEqual(delta.shape, expected_shape) + self.assertTrue(np.allclose(delta, expected_val)) + class HFToMaxTextLoRAConversionTest(unittest.TestCase): """Tests the conversion logic in to_maxtext with LoRA support.""" @@ -169,5 +258,179 @@ def test_process_and_stack_weights(self): self.assertEqual(stacked[1, 0, 0], 2.0) +class CheckpointMergingTest(unittest.TestCase): + """Tests the recursive_update and load_orbax_checkpoint functions to ensure we don't overwrite weights.""" + + def test_recursive_update(self): + + base = { + "params": { + "decoder": { + "layers": { + "kernel": np.ones((4, 4)), + } + } + } + } + lora = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": np.ones((4, 2)), + "kernel_lora_b": np.ones((2, 4)), + } + } + } + } + + merged = {} + _recursive_update(merged, base) + _recursive_update(merged, lora) + + # Verify that both base and lora weights are present and not overwritten + self.assertIn("kernel", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_a", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_b", merged["params"]["decoder"]["layers"]) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel"], np.ones((4, 4))) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel_lora_a"], np.ones((4, 2))) + np.testing.assert_array_equal(merged["params"]["decoder"]["layers"]["kernel_lora_b"], np.ones((2, 4))) + + @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer") + @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path") + @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices") + def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, mock_checkpointer_cls): + + # Mock jax devices + mock_jax_devices.return_value = [MagicMock()] + + # Mock Orbax Checkpointer and its restore results + mock_ckptr = MagicMock() + mock_checkpointer_cls.return_value = mock_ckptr + + # Base checkpoint metadata and content + base_metadata = MagicMock() + base_metadata.item_metadata.tree = {"params": {"decoder": {"layers": {"kernel": MagicMock(shape=(4, 4))}}}} + base_restore_content = {"params": {"decoder": {"layers": {"kernel": np.ones((4, 4))}}}} + + # LoRA checkpoint metadata and content + lora_metadata = MagicMock() + lora_metadata.item_metadata.tree = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": MagicMock(shape=(4, 2)), + "kernel_lora_b": MagicMock(shape=(2, 4)), + } + } + } + } + lora_restore_content = { + "params": { + "decoder": { + "layers": { + "kernel_lora_a": np.ones((4, 2)), + "kernel_lora_b": np.ones((2, 4)), + } + } + } + } + + # Mock metadata and restore calls + mock_ckptr.metadata.side_effect = [base_metadata, lora_metadata] + mock_ckptr.restore.side_effect = [base_restore_content, lora_restore_content] + + # Create dummy config + config = MagicMock() + config.checkpoint_storage_concurrent_gb = 8 + config.checkpoint_storage_use_ocdbt = True + config.checkpoint_storage_use_zarr3 = True + config.load_parameters_path = "gs://base-bucket/checkpoints" + config.lora.lora_restore_path = "gs://lora-bucket/checkpoints" + + # Load and merge + merged = load_orbax_checkpoint(config) + + # Assert checkpointer was called twice and restored both + self.assertEqual(mock_ckptr.restore.call_count, 2) + + # Verify that the keys are recursively merged correctly! + self.assertIn("kernel", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_a", merged["params"]["decoder"]["layers"]) + self.assertIn("kernel_lora_b", merged["params"]["decoder"]["layers"]) + + +class Gemma3And4CheckpointConversionTest(unittest.TestCase): + """Explicitly tests Gemma 3 and Gemma 4 formats for base, adapter-only, and merged weight transformations.""" + + def test_gemma3_base_and_adapter_conversion(self): + # Gemma 3 configuration simulation + # Scanned layers weight shapes + # Query weight: [layers, input_dim, heads, head_dim] + # For Gemma 3: 4D tensor for attention query + key = "params-decoder-scanned_blocks-self_attention-query-kernel" + a_key = key + "_lora_a" + b_key = key + "_lora_b" + + # 4D scanned attention shapes: [num_layers, input_dim, heads, rank] + # num_layers = 2, input_dim = 16, heads = 2, rank = 4, output_dim = 16 + data_a = np.ones((2, 16, 2, 4), dtype=np.float32) * 0.5 + data_b = np.ones((2, 4, 2, 16), dtype=np.float32) * 0.5 + + param_map = { + key: [ + "model.layers.0.self_attn.q_proj.weight", + "model.layers.1.self_attn.q_proj.weight", + ] + } + lora_dict = {a_key: data_a, b_key: data_b} + + # 1. Test Adapter-only transformation + weights, _ = _transform_weights_to_adapter(param_map, lora_dict) + self.assertIn("base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight", weights) + self.assertIn("base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight", weights) + self.assertIn("base_model.model.model.layers.1.self_attn.q_proj.lora_A.weight", weights) + self.assertIn("base_model.model.model.layers.1.self_attn.q_proj.lora_B.weight", weights) + + # 2. Test Delta contraction / Merged transformation math + delta = _get_lora_delta(key, lora_dict, lora_scaling=2.0) + # Expected delta shape matches original query weight shape: [2, 16, 2, 16] + self.assertEqual(delta.shape, (2, 16, 2, 16)) + # Math: einsum("lipr,lrpo->lipo", A, B) * 2.0 + # For each slice: matmul(0.5, 0.5) * rank * scaling = 0.25 * 4 * 2.0 = 2.0 + self.assertTrue(np.allclose(delta, 2.0)) + + def test_gemma4_base_and_adapter_conversion(self): + # Gemma 4 configuration simulation + # Scanned layers standard linear weight shapes (e.g. gate_proj, up_proj) + # 3D scanned linear shape Case A: [num_layers, input_dim, rank] & [num_layers, rank, output_dim] + key = "params-decoder-scanned_blocks-mlp-wi_0-kernel" + a_key = key + "_lora_a" + b_key = key + "_lora_b" + + # num_layers = 2, input_dim = 16, rank = 4, output_dim = 32 + data_a = np.ones((2, 16, 4), dtype=np.float32) * 0.5 + data_b = np.ones((2, 4, 32), dtype=np.float32) * 0.5 + + param_map = { + key: [ + "model.layers.0.mlp.gate_proj.weight", + "model.layers.1.mlp.gate_proj.weight", + ] + } + lora_dict = {a_key: data_a, b_key: data_b} + + # 1. Test Adapter-only transformation + weights, _ = _transform_weights_to_adapter(param_map, lora_dict) + self.assertIn("base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight", weights) + self.assertIn("base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight", weights) + + # 2. Test Delta contraction / Merged transformation math + delta = _get_lora_delta(key, lora_dict, lora_scaling=2.0) + # Expected delta shape matches original gate weight shape: [2, 16, 32] + self.assertEqual(delta.shape, (2, 16, 32)) + # Math: einsum("lir,lro->lio", A, B) * 2.0 -> 0.25 * 4 * 2.0 = 2.0 + self.assertTrue(np.allclose(delta, 2.0)) + + if __name__ == "__main__": unittest.main()