@@ -48,6 +48,12 @@ def load_weights(self, weights, mapper):
4848 def load_draft_weights (self , weights , mapper ):
4949 self ._events .append ("load_draft_weights" )
5050
51+ def setup_aliases (self ):
52+ self ._events .append ("setup_aliases" )
53+
54+ def cache_derived_state (self ):
55+ self ._events .append ("cache_derived_state" )
56+
5157 def post_load_weights (self ):
5258 self ._events .append ("post_load_weights" )
5359
@@ -57,8 +63,15 @@ def _moe_context(config, mapping):
5763 yield None
5864
5965
60- def _make_loader (monkeypatch , * , events , spec_config = None ):
61- llm_args = SimpleNamespace (load_format = LoadFormat .AUTO )
66+ def _make_loader (monkeypatch , * , events , spec_config = None , load_format = LoadFormat .AUTO ):
67+ llm_args = SimpleNamespace (
68+ load_format = load_format ,
69+ gms_config = SimpleNamespace (
70+ socket_path = "/tmp/gms.sock" ,
71+ mode = "test" ,
72+ tag = "test" ,
73+ ),
74+ )
6275 loader = ModelLoader (
6376 llm_args = llm_args ,
6477 mapping = MagicMock (name = "mapping" ),
@@ -77,6 +90,15 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7790 monkeypatch .setattr (model_loader_mod , "timing" , lambda * _args , ** _kwargs : nullcontext ())
7891 monkeypatch .setattr (model_loader_mod , "maybe_create_moe_load_balancer" , _moe_context )
7992 monkeypatch .setattr (model_loader_mod , "MetaInitMode" , lambda : nullcontext ())
93+ # The MX and GMS load paths build a receiver-side SourceIdentity from the
94+ # resolved ModelConfig. These tests stub the config, so short-circuit the
95+ # fingerprint construction to a sentinel; identity-comparison logic is
96+ # covered separately in test_source_identity.py.
97+ monkeypatch .setattr (
98+ model_loader_mod .SourceIdentity ,
99+ "from_model_config" ,
100+ classmethod (lambda cls , * _args , ** _kwargs : SimpleNamespace (name = "local-identity" )),
101+ )
80102 monkeypatch .setattr (
81103 model_loader_mod .AutoModelForCausalLM ,
82104 "from_config" ,
@@ -114,8 +136,12 @@ def test_mx_success_initializes_mapper_skips_weight_mapping_and_reload_works(mon
114136
115137 # reload() uses self.weight_mapper unconditionally; MX success must
116138 # initialize it even though the initial load skipped _call_load_weights.
139+ model ._weights_transformed = True
140+ model .linear ._weights_transformed = True
117141 loader .reload (model , {"reloaded" : MagicMock ()})
118142 assert loader ._call_load_weights .call_count == 1
143+ assert model ._weights_transformed is False
144+ assert model .linear ._weights_transformed is False
119145 assert events == ["post_load_weights" , "load_weights" ]
120146
121147
@@ -159,6 +185,62 @@ def test_mx_fallback_runs_standard_weight_mapping(monkeypatch):
159185 )
160186
161187
188+ def test_gms_ro_materializes_between_alias_setup_and_cache_state (monkeypatch ):
189+ events = []
190+ loader = _make_loader (monkeypatch , events = events , load_format = LoadFormat .GMS )
191+ checkpoint_loader = MagicMock (name = "checkpoint_loader" )
192+ checkpoint_loader .checkpoint_format = "GMS"
193+
194+ def record (event ):
195+ def _append (* _args , ** _kwargs ):
196+ events .append (event )
197+
198+ return _append
199+
200+ checkpoint_loader .post_load_apply .side_effect = record ("post_load_apply" )
201+ checkpoint_loader .post_load_publish .side_effect = record ("post_load_publish" )
202+
203+ # The STRICT pre-materialize identity gate runs between alias setup and
204+ # materialization; record it to pin the ordering, without exercising the
205+ # comparison logic (covered in test_source_identity.py).
206+ monkeypatch .setattr (
207+ model_loader_mod ,
208+ "check_weight_sharing_compatibility" ,
209+ lambda * _args , ** _kwargs : events .append ("check_source_identity" ),
210+ )
211+
212+ class _GmsBackend :
213+ def __init__ (self , * args , ** kwargs ):
214+ self .is_rw = False
215+
216+ def connect (self ):
217+ return True
218+
219+ def get_source_identity (self ):
220+ return SimpleNamespace (name = "remote-identity" )
221+
222+ def materialize_module (self , model ):
223+ events .append ("materialize_module" )
224+
225+ def cleanup (self ):
226+ events .append ("cleanup" )
227+
228+ monkeypatch .setattr ("tensorrt_llm._torch.memory.GMSBackend" , _GmsBackend )
229+
230+ loader .load ("/ckpt" , checkpoint_loader )
231+
232+ assert events == [
233+ "post_load_apply" ,
234+ "setup_aliases" ,
235+ "check_source_identity" ,
236+ "materialize_module" ,
237+ "cache_derived_state" ,
238+ "post_load_publish" ,
239+ ]
240+ assert "post_load_weights" not in events
241+ checkpoint_loader .load_weights .assert_not_called ()
242+
243+
162244class _HookRecorder (nn .Module ):
163245 def __init__ (
164246 self ,
0 commit comments