Skip to content

Commit 9ed7ce4

Browse files
[https://nvbugs/6337235][test] Fix MX/GMS model loader fixtures (#15471)
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent f536026 commit 9ed7ce4

2 files changed

Lines changed: 40 additions & 1 deletion

File tree

tests/unittest/_torch/pyexecutor/test_model_loader_gms.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,16 @@
1515
from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader
1616
from tensorrt_llm.llmapi.llm_args import LoadFormat
1717

18+
_SOURCE_IDENTITY = model_loader_mod.SourceIdentity(
19+
format_version=1,
20+
model_fingerprint="model",
21+
quant_fingerprint="quant",
22+
backend_fingerprint="backend",
23+
parallel_fingerprint="parallel",
24+
rank=0,
25+
shard_fingerprint="shard",
26+
)
27+
1828

1929
class _TinyModel(nn.Module):
2030
def __init__(self, events, *, include_draft=False):
@@ -71,6 +81,13 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7181
monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext())
7282
monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context)
7383
monkeypatch.setattr(model_loader_mod, "MetaInitMode", lambda: nullcontext())
84+
# These tests stub ModelConfig, while SourceIdentity has dedicated
85+
# coverage. Keep this file focused on ModelLoader GMS branch behavior.
86+
monkeypatch.setattr(
87+
model_loader_mod.SourceIdentity,
88+
"from_model_config",
89+
classmethod(lambda cls, *_args, **_kwargs: _SOURCE_IDENTITY),
90+
)
7491
monkeypatch.setattr(
7592
model_loader_mod.AutoModelForCausalLM,
7693
"from_config",
@@ -93,6 +110,7 @@ def _build_gms_backend(*, is_rw, events):
93110
if is_rw:
94111
backend.mem_pool_scope.side_effect = lambda _device: _pool_scope(events)
95112
else:
113+
backend.get_source_identity.return_value = _SOURCE_IDENTITY
96114

97115
def _materialize(_model):
98116
events.append("materialize")
@@ -169,7 +187,10 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
169187
# path (see model_loader.py); HF ignores it, MX uses it for direct
170188
# P2P writes when MX+GMS composition eventually lands.
171189
checkpoint_loader.load_weights.assert_called_once_with(
172-
"/ckpt", mapping=loader.mapping, model=model
190+
"/ckpt",
191+
mapping=loader.mapping,
192+
model=model,
193+
source_identity=loader._source_identity,
173194
)
174195
loader._call_load_weights.assert_called_once()
175196
backend.move_untracked_params.assert_called_once_with(model)

tests/unittest/_torch/pyexecutor/test_model_loader_mx.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,16 @@
1313
from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader
1414
from tensorrt_llm.llmapi.llm_args import LoadFormat
1515

16+
_SOURCE_IDENTITY = model_loader_mod.SourceIdentity(
17+
format_version=1,
18+
model_fingerprint="model",
19+
quant_fingerprint="quant",
20+
backend_fingerprint="backend",
21+
parallel_fingerprint="parallel",
22+
rank=0,
23+
shard_fingerprint="shard",
24+
)
25+
1626

1727
class _LinearStub(nn.Module):
1828
def post_load_weights(self):
@@ -77,6 +87,13 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7787
monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext())
7888
monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context)
7989
monkeypatch.setattr(model_loader_mod, "MetaInitMode", lambda: nullcontext())
90+
# These tests stub ModelConfig, while SourceIdentity has dedicated
91+
# coverage. Keep this file focused on ModelLoader MX branch behavior.
92+
monkeypatch.setattr(
93+
model_loader_mod.SourceIdentity,
94+
"from_model_config",
95+
classmethod(lambda cls, *_args, **_kwargs: _SOURCE_IDENTITY),
96+
)
8097
monkeypatch.setattr(
8198
model_loader_mod.AutoModelForCausalLM,
8299
"from_config",
@@ -105,6 +122,7 @@ def test_mx_success_initializes_mapper_skips_weight_mapping_and_reload_works(mon
105122
_args, kwargs = checkpoint_loader.load_weights.call_args
106123
assert kwargs["mapping"] is loader.mapping
107124
assert kwargs["model"] is model
125+
assert kwargs["source_identity"] is loader._source_identity
108126
assert loader._call_load_weights.call_count == 0
109127
checkpoint_loader.get_initialized_weight_mapper.assert_called_once()
110128
assert loader.weight_mapper is checkpoint_loader.get_initialized_weight_mapper.return_value

0 commit comments

Comments
 (0)