@@ -116,12 +116,35 @@ def _get_lora_delta(key, lora_state_dict, lora_scaling):
116116 a_key , b_key = key [7 :] + "_lora_a" , key [7 :] + "_lora_b"
117117
118118 if a_key in lora_state_dict and b_key in lora_state_dict :
119- data_a , data_b = jnp .asarray (lora_state_dict [a_key ], dtype = jnp .float32 ), jnp .asarray (
120- lora_state_dict [b_key ], dtype = jnp .float32
121- )
122- if data_a .ndim > 2 :
119+ data_a = jnp .asarray (lora_state_dict [a_key ], dtype = jnp .float32 )
120+ data_b = jnp .asarray (lora_state_dict [b_key ], dtype = jnp .float32 )
121+
122+ is_attention = "attention" in key .lower () or "attn" in key .lower ()
123+
124+ if is_attention and data_a .ndim > 2 :
125+ if data_a .ndim == 4 :
126+ # Scanned attention projection: [num_layers, input_dim, heads, rank] & [num_layers, rank, heads, output_dim]
127+ return jnp .einsum ("lipr,lrpo->lipo" , data_a , data_b ) * lora_scaling
128+ # Unscanned attention projection: [input_dim, heads, rank] & [rank, heads, output_dim]
123129 return jnp .einsum ("ipr,rpo->ipo" , data_a , data_b ) * lora_scaling
124- return jnp .matmul (data_a , data_b ) * lora_scaling
130+ else :
131+ if data_a .ndim == 3 :
132+ # Scanned standard linear projection: can be [num_layers, input_dim, rank] or [input_dim, num_layers, rank]
133+ rank = data_a .shape [2 ]
134+ if rank == data_b .shape [1 ] and rank != data_b .shape [0 ]:
135+ # Case A: [num_layers, input_dim, rank] & [num_layers, rank, output_dim]
136+ return jnp .einsum ("lir,lro->lio" , data_a , data_b ) * lora_scaling
137+ elif rank == data_b .shape [0 ] and rank != data_b .shape [1 ]:
138+ # Case B: [input_dim, num_layers, rank] & [rank, num_layers, output_dim]
139+ return jnp .einsum ("ilr,rlo->ilo" , data_a , data_b ) * lora_scaling
140+ else :
141+ # Disambiguate using key names (Case B is typically 'wo' or 'out-kernel' / 'out_proj')
142+ if any (term in key for term in ["wo" , "out-kernel" , "out_proj" ]):
143+ return jnp .einsum ("ilr,rlo->ilo" , data_a , data_b ) * lora_scaling
144+ else :
145+ return jnp .einsum ("lir,lro->lio" , data_a , data_b ) * lora_scaling
146+ # Unscanned standard linear projection
147+ return jnp .matmul (data_a , data_b ) * lora_scaling
125148 return None
126149
127150
@@ -286,19 +309,38 @@ def _transform_weights_to_adapter(param_map, state_dict):
286309 if a_key in state_dict and b_key in state_dict :
287310 data_a , data_b = state_dict [a_key ], state_dict [b_key ]
288311 hf_paths = [hf_paths ] if not isinstance (hf_paths , list ) else hf_paths
289- for i in range (min (data_a .shape [1 ] if data_a .ndim > 2 else 1 , len (hf_paths ))):
290- found_hf_modules .add (hf_paths [i ].split ("." )[- 2 ])
291- name = hf_paths [i ].replace (".weight" , "" )
312+ for i , hf_path in enumerate (hf_paths ):
313+ found_hf_modules .add (hf_path .split ("." )[- 2 ])
314+ name = hf_path .replace (".weight" , "" )
315+
316+ if data_a .ndim > 2 :
317+ if data_a .shape [0 ] == len (hf_paths ):
318+ # Case A: layer dimension is axis 0
319+ layer_a = data_a [i , ...]
320+ layer_b = data_b [i , ...]
321+ else :
322+ # Case B: layer dimension is axis 1
323+ layer_a = data_a [:, i , ...]
324+ layer_b = data_b [:, i , ...]
325+ else :
326+ layer_a = data_a
327+ layer_b = data_b
328+
329+ if layer_a .ndim > 2 :
330+ layer_a = layer_a [:, 0 , :]
331+ if layer_b .ndim > 2 :
332+ layer_b = layer_b [:, 0 , :]
333+
292334 processed_params_list .append (
293335 (
294336 f"base_model.model.{ name } .lora_A.weight" ,
295- jax .numpy .asarray (( data_a [:, i , :] if data_a . ndim > 2 else data_a ) .T ),
337+ jax .numpy .asarray (layer_a .T ),
296338 )
297339 )
298340 processed_params_list .append (
299341 (
300342 f"base_model.model.{ name } .lora_B.weight" ,
301- jax .numpy .asarray (( data_b [:, i , :] if data_b . ndim > 2 else data_b ) .T ),
343+ jax .numpy .asarray (layer_b .T ),
302344 )
303345 )
304346 return dict (processed_params_list ), found_hf_modules
@@ -424,9 +466,7 @@ def main(argv: Sequence[str]) -> None:
424466 maxtext_state_dict = detect_and_extract_checkpoint (checkpoint_dict )
425467
426468 # Validate that checkpoint keys match the parameter mapping
427- state_keys = set (maxtext_state_dict ) | {
428- k .replace ("_lora_a" , "" ).replace ("_lora_b" , "" ) for k in maxtext_state_dict if "_lora_" in k
429- }
469+ state_keys = {k .replace ("_lora_a" , "" ).replace ("_lora_b" , "" ) for k in maxtext_state_dict }
430470 filtered_map_keys = validate_and_filter_param_map_keys (param_map , state_keys )
431471
432472 # When not converting a multimodal model, skip vision encoder weights even if
0 commit comments