Skip to content

Commit 77eddfa

Browse files
committed
[TRTLLM-13250][fix] Avoid GMS MX double transforms
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent eeceee3 commit 77eddfa

2 files changed

Lines changed: 76 additions & 6 deletions

File tree

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -692,12 +692,16 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
692692
checkpoint_loader.post_load_apply(
693693
model, weights_preloaded=weights_preloaded)
694694

695-
for module in model.modules():
696-
if hasattr(
697-
module,
698-
'post_load_weights') and not getattr(
699-
module, '_weights_removed', False):
700-
module.post_load_weights()
695+
mx_staged_receiver_path = self._should_run_mx_staged_receiver_path(
696+
checkpoint_loader,
697+
model,
698+
weights_preloaded=weights_preloaded)
699+
if mx_staged_receiver_path:
700+
self._setup_aliases(model)
701+
self._mark_weights_transformed(model)
702+
self._walk_cache_state(model)
703+
else:
704+
self._walk_full_post_load(model)
701705

702706
# Defensive last-mile sweep: catches strays from
703707
# C++ ops that bypassed the active torch

tests/unittest/_torch/pyexecutor/test_model_loader_gms.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
class _TinyModel(nn.Module):
3030
def __init__(self, events, *, include_draft=False):
3131
super().__init__()
32+
self._weights_transformed = False
3233
self._events = events
3334
if include_draft:
3435
self.draft_model = nn.Module()
@@ -131,6 +132,19 @@ def _install_gms_backend(monkeypatch, backend):
131132
monkeypatch.setattr(memory_mod, "GMSBackend", MagicMock(return_value=backend))
132133

133134

135+
class _PostTransformMxLoader:
136+
checkpoint_format = "MX"
137+
138+
def __init__(self) -> None:
139+
self.load_weights = MagicMock(return_value={})
140+
self.is_weights_preloaded = MagicMock(return_value=True)
141+
self.post_load_apply = MagicMock()
142+
self.post_load_publish = MagicMock()
143+
144+
def is_post_transform_weights_preloaded(self) -> bool:
145+
return True
146+
147+
134148
def _spec_config_needing_draft_weights():
135149
return SimpleNamespace(
136150
spec_dec_mode=SimpleNamespace(need_load_draft_weights=lambda: True),
@@ -347,6 +361,58 @@ def test_gms_rw_loader_preload_skips_mapping_pipeline(monkeypatch):
347361
backend.finalize_write.assert_called_once_with(model)
348362

349363

364+
def test_gms_rw_mx_post_transform_preload_uses_staged_path(monkeypatch):
    """GMS writers fed post-transform MX bytes must skip the second transform.

    Checks the exact order of recorded load events, that no
    ``post_load_weights`` event appears, that the model ends up flagged
    ``_weights_transformed``, and that publish/finalize are each called once.
    """
    events = []

    def _backend_recorder(label):
        # Backend hooks take exactly one positional argument (the model).
        return lambda _model: events.append(label)

    def _loader_recorder(label):
        # Checkpoint-loader hooks may be called with any arguments.
        return lambda *_a, **_kw: events.append(label)

    loader = _make_loader(monkeypatch, events=events)

    # Put the tiny test model on the staged-receiver allowlist.
    allowlist = frozenset(
        {(_TinyModel, ModelLoader._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION)}
    )
    monkeypatch.setattr(ModelLoader, "_MX_STAGED_RECEIVER_ALLOWLIST", allowlist)

    backend = _build_gms_backend(is_rw=True, events=events)
    backend.move_untracked_params.side_effect = _backend_recorder("move_untracked_params")
    backend.finalize_write.side_effect = _backend_recorder("finalize_write")
    _install_gms_backend(monkeypatch, backend)

    checkpoint_loader = _PostTransformMxLoader()
    checkpoint_loader.post_load_apply.side_effect = _loader_recorder("post_load_apply")
    checkpoint_loader.post_load_publish.side_effect = _loader_recorder("post_load_publish")

    model, _ = loader.load("/ckpt", checkpoint_loader)

    expected_events = [
        "pool_enter",
        "_apply",
        "to",
        "post_load_apply",
        "setup_aliases",
        "cache_derived_state",
        "move_untracked_params",
        "post_load_publish",
        "pool_exit",
        "finalize_write",
    ]
    assert events == expected_events
    # Spelled out separately so a regression names the forbidden step.
    assert "post_load_weights" not in events
    assert model._weights_transformed is True

    _positional, load_kwargs = checkpoint_loader.load_weights.call_args
    assert load_kwargs["allow_post_transform_weights"] is True
    loader._call_load_weights.assert_not_called()

    checkpoint_loader.post_load_publish.assert_called_once_with(
        model,
        checkpoint_dir="/ckpt",
        weights_preloaded=True,
        source_identity=loader._source_identity,
    )
    backend.finalize_write.assert_called_once_with(model)
414+
415+
350416
def test_gms_rw_no_load_and_no_preload_raises(monkeypatch):
351417
"""RW + empty ``weights`` + ``is_weights_preloaded()=False`` is a bug.
352418

0 commit comments

Comments
 (0)