@@ -48,6 +48,12 @@ def load_weights(self, weights, mapper):
def load_draft_weights(self, weights, mapper):
    """Record that ``load_draft_weights`` was called; both arguments are ignored."""
    recorded = self._events
    recorded.append("load_draft_weights")
5050
def setup_aliases(self):
    """Record that ``setup_aliases`` was called on this stub."""
    recorded = self._events
    recorded.append("setup_aliases")
53+
def cache_derived_state(self):
    """Record that ``cache_derived_state`` was called on this stub."""
    recorded = self._events
    recorded.append("cache_derived_state")
56+
def post_load_weights(self):
    """Record that ``post_load_weights`` was called on this stub."""
    recorded = self._events
    recorded.append("post_load_weights")
5359
@@ -57,8 +63,15 @@ def _moe_context(config, mapping):
5763 yield None
5864
5965
60- def _make_loader (monkeypatch , * , events , spec_config = None ):
61- llm_args = SimpleNamespace (load_format = LoadFormat .AUTO )
66+ def _make_loader (monkeypatch , * , events , spec_config = None , load_format = LoadFormat .AUTO ):
67+ llm_args = SimpleNamespace (
68+ load_format = load_format ,
69+ gms_config = SimpleNamespace (
70+ socket_path = "/tmp/gms.sock" ,
71+ mode = "test" ,
72+ tag = "test" ,
73+ ),
74+ )
6275 loader = ModelLoader (
6376 llm_args = llm_args ,
6477 mapping = MagicMock (name = "mapping" ),
@@ -75,6 +88,15 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7588 monkeypatch .setattr (model_loader_mod , "timing" , lambda * _args , ** _kwargs : nullcontext ())
7689 monkeypatch .setattr (model_loader_mod , "maybe_create_moe_load_balancer" , _moe_context )
7790 monkeypatch .setattr (model_loader_mod , "MetaInitMode" , lambda : nullcontext ())
91+ # The MX and GMS load paths build a receiver-side SourceIdentity from the
92+ # resolved ModelConfig. These tests stub the config, so short-circuit the
93+ # fingerprint construction to a sentinel; identity-comparison logic is
94+ # covered separately in test_source_identity.py.
95+ monkeypatch .setattr (
96+ model_loader_mod .SourceIdentity ,
97+ "from_model_config" ,
98+ classmethod (lambda cls , * _args , ** _kwargs : SimpleNamespace (name = "local-identity" )),
99+ )
78100 monkeypatch .setattr (
79101 model_loader_mod .AutoModelForCausalLM ,
80102 "from_config" ,
@@ -112,8 +134,12 @@ def test_mx_success_initializes_mapper_skips_weight_mapping_and_reload_works(mon
112134
113135 # reload() uses self.weight_mapper unconditionally; MX success must
114136 # initialize it even though the initial load skipped _call_load_weights.
137+ model ._weights_transformed = True
138+ model .linear ._weights_transformed = True
115139 loader .reload (model , {"reloaded" : MagicMock ()})
116140 assert loader ._call_load_weights .call_count == 1
141+ assert model ._weights_transformed is False
142+ assert model .linear ._weights_transformed is False
117143 assert events == ["post_load_weights" , "load_weights" ]
118144
119145
@@ -157,6 +183,62 @@ def test_mx_fallback_runs_standard_weight_mapping(monkeypatch):
157183 )
158184
159185
def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
    """Pin the hook ordering of a GMS load against a non-RW backend.

    Every collaborator appends a label to a shared ``events`` list, and the
    test asserts the exact sequence produced by ``loader.load``: the backend's
    ``materialize_module`` must land after ``setup_aliases`` (and the
    source-identity check) and before ``cache_derived_state``. It also asserts
    that ``post_load_weights`` is never recorded and that the checkpoint
    loader's ``load_weights`` is never called.
    """
    events = []

    def _log(label):
        # Build a callable that ignores its arguments and records ``label``.
        return lambda *_args, **_kwargs: events.append(label)

    loader = _make_loader(monkeypatch, events=events, load_format=LoadFormat.GMS)

    checkpoint_loader = MagicMock(name="checkpoint_loader", checkpoint_format="GMS")
    for hook in ("post_load_apply", "post_load_publish"):
        getattr(checkpoint_loader, hook).side_effect = _log(hook)

    # The STRICT pre-materialize identity gate runs between alias setup and
    # materialization; record it to pin the ordering, without exercising the
    # comparison logic (covered in test_source_identity.py).
    monkeypatch.setattr(
        model_loader_mod,
        "check_weight_sharing_compatibility",
        _log("check_source_identity"),
    )

    class _FakeGmsBackend:
        """Stand-in backend that reports itself as not read-write."""

        def __init__(self, *args, **kwargs):
            self.is_rw = False

        def connect(self):
            return True

        def get_source_identity(self):
            return SimpleNamespace(name="remote-identity")

        def materialize_module(self, model):
            events.append("materialize_module")

        def cleanup(self):
            # Recorded so an unexpected cleanup call would break the exact
            # sequence assertion below.
            events.append("cleanup")

    monkeypatch.setattr("tensorrt_llm._torch.memory.GMSBackend", _FakeGmsBackend)

    loader.load("/ckpt", checkpoint_loader)

    expected_order = [
        "post_load_apply",
        "setup_aliases",
        "check_source_identity",
        "materialize_module",
        "cache_derived_state",
        "post_load_publish",
    ]
    assert events == expected_order
    assert "post_load_weights" not in events
    checkpoint_loader.load_weights.assert_not_called()
240+
241+
160242class _HookRecorder (nn .Module ):
161243 def __init__ (
162244 self ,
0 commit comments