1515from tensorrt_llm ._torch .pyexecutor .model_loader import ModelLoader
1616from tensorrt_llm .llmapi .llm_args import LoadFormat
1717
18+ _SOURCE_IDENTITY = model_loader_mod .SourceIdentity (
19+ format_version = 1 ,
20+ model_fingerprint = "model" ,
21+ quant_fingerprint = "quant" ,
22+ backend_fingerprint = "backend" ,
23+ parallel_fingerprint = "parallel" ,
24+ rank = 0 ,
25+ shard_fingerprint = "shard" ,
26+ )
27+
1828
1929class _TinyModel (nn .Module ):
2030 def __init__ (self , events , * , include_draft = False ):
@@ -71,6 +81,13 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7181 monkeypatch .setattr (model_loader_mod , "timing" , lambda * _args , ** _kwargs : nullcontext ())
7282 monkeypatch .setattr (model_loader_mod , "maybe_create_moe_load_balancer" , _moe_context )
7383 monkeypatch .setattr (model_loader_mod , "MetaInitMode" , lambda : nullcontext ())
84+ # These tests stub ModelConfig, while SourceIdentity has dedicated
85+ # coverage. Keep this file focused on ModelLoader GMS branch behavior.
86+ monkeypatch .setattr (
87+ model_loader_mod .SourceIdentity ,
88+ "from_model_config" ,
89+ classmethod (lambda cls , * _args , ** _kwargs : _SOURCE_IDENTITY ),
90+ )
7491 monkeypatch .setattr (
7592 model_loader_mod .AutoModelForCausalLM ,
7693 "from_config" ,
@@ -93,6 +110,7 @@ def _build_gms_backend(*, is_rw, events):
93110 if is_rw :
94111 backend .mem_pool_scope .side_effect = lambda _device : _pool_scope (events )
95112 else :
113+ backend .get_source_identity .return_value = _SOURCE_IDENTITY
96114
97115 def _materialize (_model ):
98116 events .append ("materialize" )
@@ -169,7 +187,10 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
169187 # path (see model_loader.py); HF ignores it, MX uses it for direct
170188 # P2P writes when MX+GMS composition eventually lands.
171189 checkpoint_loader .load_weights .assert_called_once_with (
172- "/ckpt" , mapping = loader .mapping , model = model
190+ "/ckpt" ,
191+ mapping = loader .mapping ,
192+ model = model ,
193+ source_identity = loader ._source_identity ,
173194 )
174195 loader ._call_load_weights .assert_called_once ()
175196 backend .move_untracked_params .assert_called_once_with (model )
0 commit comments