@@ -48,6 +48,12 @@ def load_weights(self, weights, mapper):
def load_draft_weights(self, weights, mapper):
    """Record that the draft-weight loading hook was invoked.

    Both arguments are accepted to match the hook signature and are
    otherwise ignored; the only effect is appending the hook name to
    ``self._events``.
    """
    self._events.append("load_draft_weights")
5050
def setup_aliases(self):
    """Record that the alias-setup hook was invoked.

    The only effect is appending the hook name to ``self._events``.
    """
    self._events.append("setup_aliases")
53+
def cache_derived_state(self):
    """Record that the derived-state caching hook was invoked.

    The only effect is appending the hook name to ``self._events``.
    """
    self._events.append("cache_derived_state")
56+
def post_load_weights(self):
    """Record that the post-load hook was invoked.

    The only effect is appending the hook name to ``self._events``.
    """
    self._events.append("post_load_weights")
5359
@@ -57,8 +63,15 @@ def _moe_context(config, mapping):
5763 yield None
5864
5965
60- def _make_loader (monkeypatch , * , events , spec_config = None ):
61- llm_args = SimpleNamespace (load_format = LoadFormat .AUTO )
66+ def _make_loader (monkeypatch , * , events , spec_config = None , load_format = LoadFormat .AUTO ):
67+ llm_args = SimpleNamespace (
68+ load_format = load_format ,
69+ gms_config = SimpleNamespace (
70+ socket_path = "/tmp/gms.sock" ,
71+ mode = "test" ,
72+ tag = "test" ,
73+ ),
74+ )
6275 loader = ModelLoader (
6376 llm_args = llm_args ,
6477 mapping = MagicMock (name = "mapping" ),
@@ -75,6 +88,15 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7588 monkeypatch .setattr (model_loader_mod , "timing" , lambda * _args , ** _kwargs : nullcontext ())
7689 monkeypatch .setattr (model_loader_mod , "maybe_create_moe_load_balancer" , _moe_context )
7790 monkeypatch .setattr (model_loader_mod , "MetaInitMode" , lambda : nullcontext ())
91+ # The MX and GMS load paths build a receiver-side SourceIdentity from the
92+ # resolved ModelConfig. These tests stub the config, so short-circuit the
93+ # fingerprint construction to a sentinel; identity-comparison logic is
94+ # covered separately in test_source_identity.py.
95+ monkeypatch .setattr (
96+ model_loader_mod .SourceIdentity ,
97+ "from_model_config" ,
98+ classmethod (lambda cls , * _args , ** _kwargs : SimpleNamespace (name = "local-identity" )),
99+ )
78100 monkeypatch .setattr (
79101 model_loader_mod .AutoModelForCausalLM ,
80102 "from_config" ,
@@ -157,6 +179,62 @@ def test_mx_fallback_runs_standard_weight_mapping(monkeypatch):
157179 )
158180
159181
def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
    """Pin the call order of the GMS load path for a read-only backend.

    A loader is built with ``LoadFormat.GMS`` and a stub backend whose
    ``is_rw`` is ``False``. Every interesting hook appends its name to a
    shared ``events`` list, and the test asserts the exact sequence: module
    materialization must sit after alias setup and the identity check, and
    before derived state is cached.
    """
    events = []
    loader = _make_loader(monkeypatch, events=events, load_format=LoadFormat.GMS)
    checkpoint_loader = MagicMock(name="checkpoint_loader")
    checkpoint_loader.checkpoint_format = "GMS"

    def record(event):
        """Return a callable that logs ``event`` and ignores its arguments."""

        def _append(*_args, **_kwargs):
            events.append(event)

        return _append

    checkpoint_loader.post_load_apply.side_effect = record("post_load_apply")
    checkpoint_loader.post_load_publish.side_effect = record("post_load_publish")

    # The STRICT pre-materialize identity gate runs between alias setup and
    # materialization; record it to pin the ordering, without exercising the
    # comparison logic (covered in test_source_identity.py).
    monkeypatch.setattr(
        model_loader_mod,
        "check_weight_sharing_compatibility",
        record("check_source_identity"),
    )

    class _GmsBackend:
        """Stand-in backend that reports read-only mode and logs its calls."""

        def __init__(self, *args, **kwargs):
            # Constructor arguments are irrelevant to the ordering under test.
            self.is_rw = False

        def connect(self):
            return True

        def get_source_identity(self):
            return SimpleNamespace(name="remote-identity")

        def materialize_module(self, model):
            events.append("materialize_module")

        def cleanup(self):
            events.append("cleanup")

    monkeypatch.setattr("tensorrt_llm._torch.memory.GMSBackend", _GmsBackend)

    loader.load("/ckpt", checkpoint_loader)

    # Exact-match on the full sequence: any extra event (including "cleanup")
    # or any reordering fails the test.
    assert events == [
        "post_load_apply",
        "setup_aliases",
        "check_source_identity",
        "materialize_module",
        "cache_derived_state",
        "post_load_publish",
    ]
    # Already implied by the equality above; kept as an explicit statement
    # that this path must not run the post_load_weights hook.
    assert "post_load_weights" not in events
    checkpoint_loader.load_weights.assert_not_called()
236+
237+
160238class _HookRecorder (nn .Module ):
161239 def __init__ (
162240 self ,
0 commit comments