2828 convert_hf_lora_key_to_maxtext ,
2929 _process_and_stack_weights ,
3030)
31+ from maxtext .checkpoint_conversion .utils .utils import (
32+ _recursive_update ,
33+ load_orbax_checkpoint ,
34+ )
3135
3236
3337class HFCheckpointConversionTest (unittest .TestCase ):
@@ -105,6 +109,48 @@ def test_transform_weights_to_adapter(self):
105109 self .assertEqual (weights ["base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight" ].shape , (20 , 4 ))
106110 self .assertIn ("q_proj" , modules )
107111
112+ # 1. Scanned standard linear Case A (3D): [num_layers, input_dim, rank] & [num_layers, rank, output_dim]
113+ param_map_scanned_a = {
114+ "params-decoder-scanned_blocks-mlp-wi_0-kernel" : [
115+ "model.layers.0.mlp.gate_proj.weight" ,
116+ "model.layers.1.mlp.gate_proj.weight" ,
117+ ]
118+ }
119+ # num_layers = 2, input_dim = 10, rank = 4, output_dim = 20
120+ data_a_scanned_a = np .ones ((2 , 10 , 4 ), dtype = np .float32 ) * 0.5
121+ data_b_scanned_a = np .ones ((2 , 4 , 20 ), dtype = np .float32 ) * 0.5
122+ lora_dict_scanned_a = {
123+ "params-decoder-scanned_blocks-mlp-wi_0-kernel_lora_a" : data_a_scanned_a ,
124+ "params-decoder-scanned_blocks-mlp-wi_0-kernel_lora_b" : data_b_scanned_a ,
125+ }
126+ weights_sa , _ = _transform_weights_to_adapter (param_map_scanned_a , lora_dict_scanned_a )
127+ self .assertIn ("base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight" , weights_sa )
128+ self .assertIn ("base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight" , weights_sa )
129+ # Since layer dimension is axis 0, layer 0 is data_a_scanned_a[0, :, :], which has shape (10, 4), transpose -> (4, 10)
130+ self .assertEqual (weights_sa ["base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight" ].shape , (4 , 10 ))
131+ self .assertEqual (weights_sa ["base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight" ].shape , (20 , 4 ))
132+
133+ # 2. Scanned standard linear Case B (3D): [input_dim, num_layers, rank] & [rank, num_layers, output_dim]
134+ param_map_scanned_b = {
135+ "params-decoder-scanned_blocks-mlp-wo-kernel" : [
136+ "model.layers.0.mlp.down_proj.weight" ,
137+ "model.layers.1.mlp.down_proj.weight" ,
138+ ]
139+ }
140+ # num_layers = 2, input_dim = 10, rank = 4, output_dim = 20
141+ data_a_scanned_b = np .ones ((10 , 2 , 4 ), dtype = np .float32 ) * 0.5
142+ data_b_scanned_b = np .ones ((4 , 2 , 20 ), dtype = np .float32 ) * 0.5
143+ lora_dict_scanned_b = {
144+ "params-decoder-scanned_blocks-mlp-wo-kernel_lora_a" : data_a_scanned_b ,
145+ "params-decoder-scanned_blocks-mlp-wo-kernel_lora_b" : data_b_scanned_b ,
146+ }
147+ weights_sb , _ = _transform_weights_to_adapter (param_map_scanned_b , lora_dict_scanned_b )
148+ self .assertIn ("base_model.model.model.layers.0.mlp.down_proj.lora_A.weight" , weights_sb )
149+ self .assertIn ("base_model.model.model.layers.0.mlp.down_proj.lora_B.weight" , weights_sb )
150+ # Since layer dimension is axis 1, layer 0 is data_a_scanned_b[:, 0, :], which has shape (10, 4), transpose -> (4, 10)
151+ self .assertEqual (weights_sb ["base_model.model.model.layers.0.mlp.down_proj.lora_A.weight" ].shape , (4 , 10 ))
152+ self .assertEqual (weights_sb ["base_model.model.model.layers.0.mlp.down_proj.lora_B.weight" ].shape , (20 , 4 ))
153+
108154 def test_transform_weights_to_full_model_merged (self ):
109155 config = MagicMock ()
110156 config .lora .lora_alpha = 32.0
@@ -125,6 +171,49 @@ def test_transform_weights_to_full_model_merged(self):
125171 self .assertIn ("model.layers.0.self_attn.q_proj.weight" , weights )
126172 self .assertTrue (np .allclose (weights ["model.layers.0.self_attn.q_proj.weight" ], self .expected_merged_val ))
127173
def test_get_lora_delta_scanned_and_unscanned_variants(self):
  """Checks the shape and value of `_get_lora_delta` across several LoRA weight layouts.

  Every case fills both LoRA factors with 0.5 and passes 2.0 as the third argument of
  `_get_lora_delta` (presumably the LoRA scaling factor -- confirm against its definition).
  The expected values are consistent with contracted_dim * 0.5 * 0.5 * 2.0: 2.0 when the
  shared dimension is 4, and 1.5 when it is 3.
  """
  fill_value = 0.5
  scale = 2.0

  cases = [
      # (name, key, shape_a, shape_b, expected_shape, expected_val)
      ("2d_linear", "params-decoder-layers-layers_0-mlp-wi_0-kernel", (10, 4), (4, 20), (10, 20), 2.0),
      (
          "3d_unscanned_attn",
          "params-decoder-layers-layers_0-self_attention-query-kernel",
          (10, 2, 4),
          (4, 2, 20),
          (10, 2, 20),
          2.0,
      ),
      (
          "3d_scanned_linear_a",
          "params-decoder-scanned_blocks-mlp-wi_0-kernel",
          (3, 10, 4),
          (3, 4, 20),
          (3, 10, 20),
          2.0,
      ),
      ("3d_scanned_linear_b", "params-decoder-scanned_blocks-mlp-wo-kernel", (10, 3, 4), (4, 3, 20), (10, 3, 20), 2.0),
      (
          "4d_scanned_attn",
          "params-decoder-scanned_blocks-self_attention-query-kernel",
          (3, 10, 2, 4),
          (3, 4, 2, 20),
          (3, 10, 2, 20),
          2.0,
      ),
      # The two edge cases share identical shapes whose leading three dims are all 3; only the
      # key differs. NOTE(review): presumably this exercises layout detection when the shapes
      # alone are ambiguous -- confirm against `_get_lora_delta`.
      ("edge_case_a", "params-decoder-scanned_blocks-mlp-wi_0-kernel", (3, 3, 3), (3, 3, 20), (3, 3, 20), 1.5),
      ("edge_case_b", "params-decoder-scanned_blocks-mlp-wo-kernel", (3, 3, 3), (3, 3, 20), (3, 3, 20), 1.5),
  ]

  for name, key, shape_a, shape_b, expected_shape, expected_val in cases:
    with self.subTest(name=name):
      state_dict = {
          f"{key}_lora_a": np.full(shape_a, fill_value, dtype=np.float32),
          f"{key}_lora_b": np.full(shape_b, fill_value, dtype=np.float32),
      }
      delta = _get_lora_delta(key, state_dict, scale)
      self.assertEqual(delta.shape, expected_shape)
      # assert_allclose reports the mismatching elements on failure, unlike
      # assertTrue(np.allclose(...)). The tolerances are np.allclose's defaults, so the
      # pass/fail threshold is unchanged.
      np.testing.assert_allclose(
          delta,
          expected_val,
          rtol=1e-5,
          atol=1e-8,
          err_msg=f"unexpected LoRA delta values for case {name!r}",
      )
216+
128217
129218class HFToMaxTextLoRAConversionTest (unittest .TestCase ):
130219 """Tests the conversion logic in to_maxtext with LoRA support."""
@@ -169,5 +258,106 @@ def test_process_and_stack_weights(self):
169258 self .assertEqual (stacked [1 , 0 , 0 ], 2.0 )
170259
171260
class CheckpointMergingTest(unittest.TestCase):
  """Verifies that merging base and LoRA checkpoint trees keeps the weights of both.

  `_recursive_update` is exercised directly; `load_orbax_checkpoint` is exercised with the
  checkpointer class, `epath.Path` and `jax.devices` patched out.
  """

  @staticmethod
  def _under_layers(leaves):
    """Returns `leaves` nested under the params -> decoder -> layers path used by the fixtures."""
    return {"params": {"decoder": {"layers": leaves}}}

  def test_recursive_update(self):
    """Applying the base tree and then the LoRA tree to an empty dict keeps all three leaves."""
    expected_leaves = {
        "kernel": np.ones((4, 4)),
        "kernel_lora_a": np.ones((4, 2)),
        "kernel_lora_b": np.ones((2, 4)),
    }
    base_tree = self._under_layers({"kernel": np.ones((4, 4))})
    lora_tree = self._under_layers(
        {
            "kernel_lora_a": np.ones((4, 2)),
            "kernel_lora_b": np.ones((2, 4)),
        }
    )

    merged = {}
    for tree in (base_tree, lora_tree):
      _recursive_update(merged, tree)

    # Presence first, then values: neither update may drop or clobber the other's leaves.
    merged_layers = merged["params"]["decoder"]["layers"]
    for leaf_name in expected_leaves:
      self.assertIn(leaf_name, merged_layers)
    for leaf_name, expected in expected_leaves.items():
      np.testing.assert_array_equal(merged_layers[leaf_name], expected)

  @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.ocp.Checkpointer")
  @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.epath.Path")
  @unittest.mock.patch("maxtext.checkpoint_conversion.utils.utils.jax.devices")
  def test_load_orbax_checkpoint_recursive_merge(self, mock_jax_devices, mock_path, mock_checkpointer_cls):
    """Loading with both a base path and a LoRA restore path yields one tree holding every leaf."""
    # One stand-in device is returned in place of the real jax.devices() result.
    mock_jax_devices.return_value = [MagicMock()]

    # Every Checkpointer the code under test constructs is this single mock instance.
    checkpointer = MagicMock()
    mock_checkpointer_cls.return_value = checkpointer

    # Metadata trees mirror the restored trees, with shape-only leaves.
    base_metadata = MagicMock()
    base_metadata.item_metadata.tree = self._under_layers({"kernel": MagicMock(shape=(4, 4))})
    lora_metadata = MagicMock()
    lora_metadata.item_metadata.tree = self._under_layers(
        {
            "kernel_lora_a": MagicMock(shape=(4, 2)),
            "kernel_lora_b": MagicMock(shape=(2, 4)),
        }
    )

    base_restored = self._under_layers({"kernel": np.ones((4, 4))})
    lora_restored = self._under_layers(
        {
            "kernel_lora_a": np.ones((4, 2)),
            "kernel_lora_b": np.ones((2, 4)),
        }
    )

    # side_effect lists are consumed in call order: the first metadata/restore call gets the
    # base fixtures and the second gets the LoRA fixtures.
    checkpointer.metadata.side_effect = [base_metadata, lora_metadata]
    checkpointer.restore.side_effect = [base_restored, lora_restored]

    # Stand-in config carrying the storage settings and the two checkpoint locations.
    config = MagicMock()
    config.checkpoint_storage_concurrent_gb = 8
    config.checkpoint_storage_use_ocdbt = True
    config.checkpoint_storage_use_zarr3 = True
    config.load_parameters_path = "gs://base-bucket/checkpoints"
    config.lora.lora_restore_path = "gs://lora-bucket/checkpoints"

    merged = load_orbax_checkpoint(config)

    # Both checkpoints must have been restored ...
    self.assertEqual(checkpointer.restore.call_count, 2)

    # ... and their leaves must sit side by side under the same nested path.
    merged_layers = merged["params"]["decoder"]["layers"]
    for leaf_name in ("kernel", "kernel_lora_a", "kernel_lora_b"):
      self.assertIn(leaf_name, merged_layers)
361+
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
  unittest.main()
0 commit comments