@@ -34,6 +34,12 @@ def to(self, *args, **kwargs):
3434 def load_weights (self , weights , mapper ):
3535 self ._events .append ("load_weights" )
3636
37+ def setup_aliases (self ):
38+ self ._events .append ("setup_aliases" )
39+
40+ def cache_derived_state (self ):
41+ self ._events .append ("cache_derived_state" )
42+
3743 def post_load_weights (self ):
3844 self ._events .append ("post_load_weights" )
3945
@@ -66,11 +72,21 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
6672 loader ._call_load_weights = MagicMock (
6773 side_effect = lambda fn , weights , mapper , ** kwargs : fn (weights , mapper )
6874 )
69- loader ._load_and_validate_config = MagicMock (return_value = SimpleNamespace (name = "config" ))
75+ loader ._load_and_validate_config = MagicMock (
76+ return_value = SimpleNamespace (name = "config" , mapping = SimpleNamespace ())
77+ )
7078
7179 monkeypatch .setattr (model_loader_mod , "timing" , lambda * _args , ** _kwargs : nullcontext ())
7280 monkeypatch .setattr (model_loader_mod , "maybe_create_moe_load_balancer" , _moe_context )
7381 monkeypatch .setattr (model_loader_mod , "MetaInitMode" , lambda : nullcontext ())
82+ # GMS builds a receiver-side SourceIdentity from the resolved ModelConfig.
83+ # These tests stub the config, so short-circuit fingerprint construction to
84+ # a sentinel; identity-comparison behavior is covered in test_source_identity.py.
85+ monkeypatch .setattr (
86+ model_loader_mod .SourceIdentity ,
87+ "from_model_config" ,
88+ classmethod (lambda cls , * _args , ** _kwargs : SimpleNamespace (name = "local-identity" )),
89+ )
7490 monkeypatch .setattr (
7591 model_loader_mod .AutoModelForCausalLM ,
7692 "from_config" ,
@@ -93,6 +109,7 @@ def _build_gms_backend(*, is_rw, events):
93109 if is_rw :
94110 backend .mem_pool_scope .side_effect = lambda _device : _pool_scope (events )
95111 else :
112+ backend .get_source_identity .return_value = None
96113
97114 def _materialize (_model ):
98115 events .append ("materialize" )
@@ -129,7 +146,7 @@ def _spec_config_needing_draft_weights():
129146 ),
130147 pytest .param (
131148 False ,
132- ["post_load_weights " , "materialize" ],
149+ ["setup_aliases " , "materialize" , "cache_derived_state " ],
133150 id = "ro" ,
134151 ),
135152 ],
@@ -143,8 +160,9 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
143160 (``_apply`` for meta materialization, ``to('cuda')``, weight
144161 load, ``post_load_weights``) inside the pool, then commits via
145162 ``finalize_write`` once the scope exits.
146- ro: the reader runs ``post_load_weights`` to wire module aliases
147- first, then GMS materializes weights via zero-copy mapping.
163+ ro: the reader runs ``setup_aliases`` to wire module aliases, checks
164+ identity compatibility, materializes weights via zero-copy mapping,
165+ then refreshes derived state from real tensors.
148166 """
149167 events = []
150168 loader = _make_loader (monkeypatch , events = events )
@@ -175,13 +193,55 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
175193 backend .move_untracked_params .assert_called_once_with (model )
176194 backend .finalize_write .assert_called_once_with (model )
177195 else :
178- # RO: post_load_weights () must run before the GMS materialize
179- # step so module aliases are wired up before zero-copy mapping.
196+ # RO: setup_aliases () must run before the GMS materialize step so
197+ # module aliases are wired up before zero-copy mapping.
180198 checkpoint_loader .load_weights .assert_not_called ()
181199 loader ._call_load_weights .assert_not_called ()
182200 backend .materialize_module .assert_called_once_with (model )
183201
184202
203+ def test_gms_ro_materializes_between_alias_setup_and_cache_state (monkeypatch ):
204+ events = []
205+ loader = _make_loader (monkeypatch , events = events )
206+ backend = _build_gms_backend (is_rw = False , events = events )
207+ _install_gms_backend (monkeypatch , backend )
208+
209+ checkpoint_loader = MagicMock (name = "checkpoint_loader" )
210+ checkpoint_loader .checkpoint_format = "HF"
211+
212+ def record (event ):
213+ def _append (* _args , ** _kwargs ):
214+ events .append (event )
215+
216+ return _append
217+
218+ checkpoint_loader .post_load_apply .side_effect = record ("post_load_apply" )
219+ checkpoint_loader .post_load_publish .side_effect = record ("post_load_publish" )
220+
221+ # The STRICT pre-materialize identity gate runs between alias setup and
222+ # materialization; record it to pin the ordering without exercising the
223+ # comparison logic, which is covered in test_source_identity.py.
224+ monkeypatch .setattr (
225+ model_loader_mod ,
226+ "check_weight_sharing_compatibility" ,
227+ lambda * _args , ** _kwargs : events .append ("check_source_identity" ),
228+ )
229+
230+ loader .load ("/ckpt" , checkpoint_loader )
231+
232+ assert events == [
233+ "post_load_apply" ,
234+ "setup_aliases" ,
235+ "check_source_identity" ,
236+ "materialize" ,
237+ "cache_derived_state" ,
238+ "post_load_publish" ,
239+ ]
240+ assert "post_load_weights" not in events
241+ checkpoint_loader .load_weights .assert_not_called ()
242+ backend .materialize_module .assert_called_once ()
243+
244+
185245def test_gms_rw_post_load_runs_inside_pool_before_finalize (monkeypatch ):
186246 """Every step that may allocate or rebind tensors must run inside the GMS pool.
187247
0 commit comments