@@ -485,8 +485,9 @@ def init_meta_tensor(t: torch.Tensor):
485485 # post_load_* hooks itself, so the shared post-load block below
486486 # must skip them. RW handles them inside `mem_pool_scope` so the
487487 # committed pool reflects the post-post_load layout; RO runs
488- # `module.post_load_weights()` before `materialize_module` to
489- # wire aliases prior to zero-copy mapping.
488+ # `setup_aliases()` before `materialize_module` to wire aliases
489+ # prior to zero-copy mapping, then refreshes derived state after
490+ # real GMS tensors are bound.
490491 gms_post_load_handled = False
491492 if load_format == LoadFormat .AUTO :
492493 # Pass model= so format-specific loaders (e.g. MX) can
@@ -717,29 +718,31 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
717718 # Hook order:
718719 # 1. `post_load_apply`: format-specific apply
719720 # work (e.g., MX preshard markers).
720- # 2. Per-module `post_load_weights`: creates
721- # aliases/derived parameter attributes BEFORE
722- # `materialize_module` walks the final module
723- # tree (including `draft_model` for spec dec).
724- # 3. `materialize_module`: zero-copy bind GMS
721+ # 2. Per-module `setup_aliases`: creates structural
722+ # aliases BEFORE `materialize_module` walks the
723+ # final module tree (including `draft_model` for
724+ # spec dec).
725+ # 3. SourceIdentity gate: STRICT pre-materialize
726+ # compatibility check (GMS has no disk fallback).
727+ # 4. `materialize_module`: zero-copy bind GMS
725728 # pool storage onto the model parameters.
726- # 4. `post_load_publish`: any receiver-side
729+ # 5. Per-module `cache_derived_state`: recompute
730+ # Python-side state from real, materialized
731+ # tensors without re-running one-shot transforms.
732+ # 6. `post_load_publish`: any receiver-side
727733 # publish (no-op via the receiver guard).
728734 checkpoint_loader .post_load_apply (
729735 model , weights_preloaded = True )
730736
731- for module in model .modules ():
732- if hasattr (module ,
733- 'post_load_weights' ) and not getattr (
734- module , '_weights_removed' , False ):
735- module .post_load_weights ()
737+ self ._setup_aliases (model )
736738
737739 # Pre-materialize compatibility gate. GMS has no
738740 # disk-fallback path, so a mismatch raises under STRICT
739741 # rather than falling back.
740742 self ._check_gms_source_identity (gms_backend )
741743
742744 gms_backend .materialize_module (model )
745+ self ._walk_cache_state (model )
743746
744747 checkpoint_loader .post_load_publish (
745748 model ,
@@ -829,22 +832,24 @@ def _check_gms_source_identity(self, gms_backend) -> None:
829832
830833 @staticmethod
831834 def _setup_aliases (model : DecoderModelForCausalLM ) -> None :
832- """Run top-level structural alias setup if the model defines it.
835+ """Run structural alias setup on eligible modules.
833836
834- Alias wiring is a model-level concern. It is intentionally not a
835- recursive module walk, because migrated aliases are expected to be set
836- by the root model that owns the layer graph.
837+ The walk is duck-typed so modules can opt in without inheriting a
838+ shared base class. Modules whose weights were removed are skipped,
839+ matching the legacy full post-load walk.
837840
838841 Args:
839- model: Root decoder model whose top-level alias hook should run.
842+ model: Root decoder model whose module tree should be visited.
840843
841844 Returns:
842845 None.
843846 """
844- setup_aliases : Optional [Callable [[], None ]] = getattr (
845- model , 'setup_aliases' , None )
846- if setup_aliases is not None :
847- setup_aliases ()
847+ for module in model .modules ():
848+ setup_aliases : Optional [Callable [[], None ]] = getattr (
849+ module , 'setup_aliases' , None )
850+ if setup_aliases is not None and not getattr (
851+ module , '_weights_removed' , False ):
852+ setup_aliases ()
848853
849854 @staticmethod
850855 def _walk_transform (model : DecoderModelForCausalLM ) -> None :
@@ -935,8 +940,11 @@ def reload(self,
935940 """Reload model weights without running post-load hooks.
936941
937942 Reload is used by incremental update paths that may provide only a
938- partial set of replacement weights. The owner of the update lifecycle is
939- responsible for running post-load processing once all bytes are present.
943+ partial set of replacement weights. Full reloads reset transform guards
944+ before rebinding fresh weights. Partial reloads keep existing transform
945+ guards intact because untouched modules may already contain transformed
946+ live weights. The owner of the update lifecycle is responsible for
947+ running post-load processing once all bytes are present.
940948
941949 Args:
942950 model: Model instance receiving the replacement weights.
@@ -952,6 +960,8 @@ def reload(self,
952960 "Cannot reload weights: weight_mapper was not initialized. "
953961 "This can happen when the initial load used GMS, MX P2P, or "
954962 "VISION_ONLY, which bypass the standard weight mapping path." )
963+ if not allow_partial_loading :
964+ self ._reset_weights_transformed (model )
955965 self ._call_load_weights (model .load_weights ,
956966 weights ,
957967 self .weight_mapper ,
0 commit comments