@@ -150,6 +150,13 @@ def make_tree_for_params(
150150 if self .has_mm_params and not params .has_mm_params :
151151 ckpt_params = _add_skip_mm_params (ckpt_params , metadata )
152152
153+ # Reconcile known structural mismatches between model-init and
154+ # checkpoint (e.g. Gemma4 LoRA: empty wrapper stubs from split_params,
155+ # leaf-vs-dict format from nn.share_scope). Only triggers when
156+ # mismatches are detected; Gemma3/legacy paths are unchanged.
157+ if _needs_reconciliation (ckpt_params , self .nested_tree ):
158+ ckpt_params = _reconcile_tree (ckpt_params , self .nested_tree )
159+
153160 # 2. Reformat the nested tree to match the checkpoint structure.
154161 if self .type == _CheckpointType .NESTED :
155162 target_params = ckpt_params # No need to reformat
@@ -170,7 +177,11 @@ def make_tree_for_params(
170177
@functools.cached_property
def has_mm_params(self) -> bool:
  """Whether the nested tree has a top-level multimodal encoder entry.

  Returns:
    True when `self.nested_tree` contains a `'vision_encoder'` or an
    `'audio_encoder'` key, False otherwise.
  """
  encoder_keys = ('vision_encoder', 'audio_encoder')
  return any(key in self.nested_tree for key in encoder_keys)
174185
175186 @functools .cached_property
176187 def has_audio_input_embedding (self ) -> bool :
@@ -395,10 +406,22 @@ def _remove_mm_params(params):
395406 # TODO(epot): Once orbax supports partial restore, we would not need to
396407 # load those extra params in the first place.
397408
398- del params ['vision_encoder' ]
399- for k in ('mm_input_projection' , 'mm_soft_embedding_norm' ):
409+ # Vision params
410+ if 'vision_encoder' in params :
411+ del params ['vision_encoder' ]
412+ for k in ('mm_input_projection' , 'mm_soft_embedding_norm' ,
413+ 'mm_pre_projection_norm' , 'mm_input_embedding_extra' ):
414+ if k in params .get ('embedder' , {}):
415+ del params ['embedder' ][k ]
416+
417+ # Audio params (Gemma4)
418+ if 'audio_encoder' in params :
419+ del params ['audio_encoder' ]
420+ for k in ('audio_input_projection' , 'audio_soft_embedding_norm' ,
421+ 'audio_input_embedding' , 'audio_input_embedding_extra' ):
400422 if k in params .get ('embedder' , {}):
401423 del params ['embedder' ][k ]
424+
402425 return params
403426
404427
@@ -407,14 +430,129 @@ def _add_skip_mm_params(params: Params, metadata: _CheckpointTree) -> Params:
407430 params = etree .copy (params )
408431 params_with_mm = metadata .nested_tree
409432
410- params ['vision_encoder' ] = params_with_mm ['vision_encoder' ]
411- for k in ('mm_input_projection' , 'mm_soft_embedding_norm' ):
412- if k in params_with_mm .get ('embedder' , {}):
413- params ['embedder' ][k ] = params_with_mm ['embedder' ][k ]
433+ # Known top-level multimodal encoder keys.
434+ _MM_TOP_LEVEL_KEYS = ('vision_encoder' , 'audio_encoder' )
435+ # Known embedder-level multimodal projection/norm keys.
436+ _MM_EMBEDDER_KEYS = (
437+ # Vision
438+ 'mm_input_projection' ,
439+ 'mm_soft_embedding_norm' ,
440+ 'mm_pre_projection_norm' ,
441+ 'mm_input_embedding_extra' ,
442+ # Audio
443+ 'audio_input_projection' ,
444+ 'audio_soft_embedding_norm' ,
445+ 'audio_input_embedding' ,
446+ 'audio_input_embedding_extra' ,
447+ )
448+
449+ for k in _MM_TOP_LEVEL_KEYS :
450+ if k in params_with_mm and k not in params :
451+ params [k ] = params_with_mm [k ]
452+
453+ embedder_mm = params_with_mm .get ('embedder' , {})
454+ for k in _MM_EMBEDDER_KEYS :
455+ if k in embedder_mm and k not in params .get ('embedder' , {}):
456+ params ['embedder' ][k ] = embedder_mm [k ]
414457
415458 return params
416459
417460
461+ def _needs_reconciliation (params : Params , metadata_tree : Params ) -> bool :
462+ """Returns True if the model params tree has known structural mismatches.
463+
464+ Detects two patterns that arise when LoRA interceptors interact with
465+ models using ``nn.share_scope`` (e.g. Gemma4 FeedForward):
466+
467+ 1. **Empty ``{}`` stubs** left by ``peft.split_params`` at LoRA wrapper
468+ scopes (e.g. ``_LoRAEinsum_0``). These keys exist in the model tree
469+ but not in the checkpoint.
470+ 2. **Leaf-vs-dict format**: ``nn.share_scope`` flattens ``{'w': array}``
471+ to bare ``ArrayImpl`` in the model-init tree, while the checkpoint
472+ keeps the dict format.
473+
474+ This check is intentionally conservative — it returns ``False`` for
475+ Gemma3 and legacy checkpoints, so their restore paths are unchanged.
476+ """
477+ if not isinstance (params , dict ) or not isinstance (metadata_tree , dict ):
478+ return False
479+
480+ for k , p_val in params .items ():
481+ if k not in metadata_tree :
482+ # Key in model but not in checkpoint (e.g. LoRA stub).
483+ if isinstance (p_val , dict ) and not p_val :
484+ return True
485+ continue
486+ m_val = metadata_tree [k ]
487+ # Leaf-vs-dict mismatch.
488+ if not isinstance (p_val , dict ) and isinstance (m_val , dict ):
489+ return True
490+ # Recurse into sub-dicts.
491+ if isinstance (p_val , dict ) and isinstance (m_val , dict ):
492+ if _needs_reconciliation (p_val , m_val ):
493+ return True
494+
495+ return False
496+
497+
498+ def _reconcile_tree (params : Params , metadata_tree : Params ) -> Params :
499+ """Align model-init params tree to match checkpoint metadata structure.
500+
501+ Only called when :func:`_needs_reconciliation` returns ``True``.
502+
503+ Handles two known mismatches between ``model.init()`` and on-disk
504+ checkpoints:
505+
506+ 1. **Empty stubs**: LoRA wrappers (or other interceptors) may leave
507+ empty dict scopes in the params tree that don't exist in the
508+ checkpoint. These are dropped.
509+ 2. **Leaf-vs-dict format**: ``nn.share_scope`` in Gemma4 ``FeedForward``
510+ flattens ``{'w': array}`` to bare ``ArrayImpl`` during model init.
511+ When the checkpoint stores ``{'w': array}``, the leaf is wrapped to
512+ match.
513+
514+ Args:
515+ params: The model-init params tree (may contain stubs / format
516+ mismatches).
517+ metadata_tree: The checkpoint metadata tree (ground-truth structure).
518+
519+ Returns:
520+ A new params tree aligned to the checkpoint metadata structure.
521+ """
522+ if not isinstance (params , dict ) or not isinstance (metadata_tree , dict ):
523+ return params
524+
525+ result = {}
526+ for k in metadata_tree :
527+ if k not in params :
528+ # Key in checkpoint but not in model (e.g. MM params handled by
529+ # _add_skip_mm_params separately) — skip.
530+ continue
531+ p_val = params [k ]
532+ m_val = metadata_tree [k ]
533+
534+ if isinstance (p_val , dict ) and isinstance (m_val , dict ):
535+ # Both dicts — recurse.
536+ reconciled = _reconcile_tree (p_val , m_val )
537+ if reconciled : # Drop if empty after reconciliation.
538+ result [k ] = reconciled
539+ elif not isinstance (p_val , dict ) and isinstance (m_val , dict ):
540+ # Model has leaf (ArrayImpl), checkpoint has dict ({'w': ...}).
541+ # Wrap the leaf to match checkpoint format.
542+ if len (m_val ) == 1 :
543+ inner_key = next (iter (m_val ))
544+ result [k ] = {inner_key : p_val }
545+ else :
546+ result [k ] = p_val # Fallback: keep as-is.
547+ else :
548+ # Both leaves, or model has dict but checkpoint has leaf.
549+ result [k ] = p_val
550+
551+ # Keys in params but NOT in metadata are intentionally dropped.
552+ # This strips LoRA wrapper stubs (_LoRAEinsum_0, etc.).
553+ return result
554+
555+
418556def _is_flat_layout (params : Params ) -> bool :
419557 """Returns True if the structure is the legacy one."""
420558 return (not _is_stacked_layout (params )) and all (
0 commit comments