@@ -34,7 +34,13 @@ def to(self, *args, **kwargs):
def load_weights(self, weights, mapper):
    """Append the ``"load_weights"`` marker to ``self._events``.

    ``weights`` and ``mapper`` are accepted but not used by this method.
    """
    event_log = self._events
    event_log.append("load_weights")
3636
37- def post_load_weights (self ):
def setup_aliases(self) -> None:
    """Append the ``"setup_aliases"`` marker to ``self._events``."""
    # NOTE(review): presumably the same list the tests compare against
    # their expected call order -- confirm against the enclosing class.
    event_log = self._events
    event_log.append("setup_aliases")
39+
def cache_derived_state(self) -> None:
    """Append the ``"cache_derived_state"`` marker to ``self._events``."""
    # NOTE(review): presumably the same list the tests compare against
    # their expected call order -- confirm against the enclosing class.
    event_log = self._events
    event_log.append("cache_derived_state")
42+
def post_load_weights(self) -> None:
    """Append the ``"post_load_weights"`` marker to ``self._events``."""
    # NOTE(review): presumably the same list the tests compare against
    # their expected call order -- confirm against the enclosing class.
    event_log = self._events
    event_log.append("post_load_weights")
3945
4046
@@ -66,11 +72,21 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
6672 loader ._call_load_weights = MagicMock (
6773 side_effect = lambda fn , weights , mapper , ** kwargs : fn (weights , mapper )
6874 )
69- loader ._load_and_validate_config = MagicMock (return_value = SimpleNamespace (name = "config" ))
75+ loader ._load_and_validate_config = MagicMock (
76+ return_value = SimpleNamespace (name = "config" , mapping = SimpleNamespace ())
77+ )
7078
7179 monkeypatch .setattr (model_loader_mod , "timing" , lambda * _args , ** _kwargs : nullcontext ())
7280 monkeypatch .setattr (model_loader_mod , "maybe_create_moe_load_balancer" , _moe_context )
7381 monkeypatch .setattr (model_loader_mod , "MetaInitMode" , lambda : nullcontext ())
82+ # GMS builds a receiver-side SourceIdentity from the resolved ModelConfig.
83+ # These tests stub the config, so short-circuit fingerprint construction to
84+ # a sentinel; identity-comparison behavior is covered in test_source_identity.py.
85+ monkeypatch .setattr (
86+ model_loader_mod .SourceIdentity ,
87+ "from_model_config" ,
88+ classmethod (lambda cls , * _args , ** _kwargs : SimpleNamespace (name = "local-identity" )),
89+ )
7490 monkeypatch .setattr (
7591 model_loader_mod .AutoModelForCausalLM ,
7692 "from_config" ,
@@ -93,6 +109,7 @@ def _build_gms_backend(*, is_rw, events):
93109 if is_rw :
94110 backend .mem_pool_scope .side_effect = lambda _device : _pool_scope (events )
95111 else :
112+ backend .get_source_identity .return_value = None
96113
97114 def _materialize (_model ):
98115 events .append ("materialize" )
@@ -129,7 +146,7 @@ def _spec_config_needing_draft_weights():
129146 ),
130147 pytest .param (
131148 False ,
132- ["post_load_weights " , "materialize" ],
149+ ["setup_aliases " , "materialize" , "cache_derived_state " ],
133150 id = "ro" ,
134151 ),
135152 ],
@@ -143,8 +160,9 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
143160 (``_apply`` for meta materialization, ``to('cuda')``, weight
144161 load, ``post_load_weights``) inside the pool, then commits via
145162 ``finalize_write`` once the scope exits.
146- ro: the reader runs ``post_load_weights`` to wire module aliases
147- first, then GMS materializes weights via zero-copy mapping.
163+ ro: the reader runs ``setup_aliases`` to wire module aliases, checks
164+ identity compatibility, materializes weights via zero-copy mapping,
165+ then refreshes derived state from real tensors.
148166 """
149167 events = []
150168 loader = _make_loader (monkeypatch , events = events )
@@ -169,19 +187,64 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
169187 # path (see model_loader.py); HF ignores it, MX uses it for direct
170188 # P2P writes when MX+GMS composition eventually lands.
171189 checkpoint_loader .load_weights .assert_called_once_with (
172- "/ckpt" , mapping = loader .mapping , model = model
190+ "/ckpt" ,
191+ mapping = loader .mapping ,
192+ model = model ,
193+ source_identity = loader ._source_identity ,
173194 )
174195 loader ._call_load_weights .assert_called_once ()
175196 backend .move_untracked_params .assert_called_once_with (model )
176197 backend .finalize_write .assert_called_once_with (model )
177198 else :
178- # RO: post_load_weights () must run before the GMS materialize
179- # step so module aliases are wired up before zero-copy mapping.
199+ # RO: setup_aliases () must run before the GMS materialize step so
200+ # module aliases are wired up before zero-copy mapping.
180201 checkpoint_loader .load_weights .assert_not_called ()
181202 loader ._call_load_weights .assert_not_called ()
182203 backend .materialize_module .assert_called_once_with (model )
183204
184205
def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
    """Pin the complete event sequence of a GMS read-only load.

    The checkpoint loader's ``post_load_apply`` / ``post_load_publish`` hooks
    and the module-level ``check_weight_sharing_compatibility`` function are
    replaced with recorders that append to the shared ``events`` list, so a
    single list comparison asserts that ``materialize`` happens after
    ``setup_aliases`` (and the identity check) and before
    ``cache_derived_state``.
    """
    events = []
    loader = _make_loader(monkeypatch, events=events)
    ro_backend = _build_gms_backend(is_rw=False, events=events)
    _install_gms_backend(monkeypatch, ro_backend)

    def _recorder(label):
        # Accept and ignore any call signature; only the label is logged.
        return lambda *_args, **_kwargs: events.append(label)

    ckpt_loader = MagicMock(name="checkpoint_loader")
    ckpt_loader.checkpoint_format = "HF"
    ckpt_loader.post_load_apply.side_effect = _recorder("post_load_apply")
    ckpt_loader.post_load_publish.side_effect = _recorder("post_load_publish")

    # The STRICT pre-materialize identity gate sits between alias setup and
    # materialization. Recording it pins the ordering without exercising the
    # comparison logic itself, which is covered in test_source_identity.py.
    monkeypatch.setattr(
        model_loader_mod,
        "check_weight_sharing_compatibility",
        _recorder("check_source_identity"),
    )

    loader.load("/ckpt", ckpt_loader)

    expected_order = [
        "post_load_apply",
        "setup_aliases",
        "check_source_identity",
        "materialize",
        "cache_derived_state",
        "post_load_publish",
    ]
    assert events == expected_order
    # The RO path must not fall back to the monolithic post-load hook.
    assert "post_load_weights" not in events
    ckpt_loader.load_weights.assert_not_called()
    ro_backend.materialize_module.assert_called_once()
246+
247+
185248def test_gms_rw_post_load_runs_inside_pool_before_finalize (monkeypatch ):
186249 """Every step that may allocate or rebind tensors must run inside the GMS pool.
187250
0 commit comments