Skip to content

Commit 431b816

Browse files
committed
[TRTLLM-13248][feat] Wave 4 add MX staged receiver cutover
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent 9834935 commit 431b816

5 files changed

Lines changed: 273 additions & 4 deletions

File tree

tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
from abc import ABC, abstractmethod
25
from typing import Any
36

@@ -69,6 +72,17 @@ def is_weights_preloaded(self) -> bool:
6972
"""Whether the last load wrote weights directly into the model."""
7073
return False
7174

75+
def is_post_transform_weights_preloaded(self) -> bool:
    """Report whether the last direct preload carried post-transform weights.

    A stricter signal than :meth:`is_weights_preloaded`. Writing bytes
    straight into the model does not imply those bytes already use the
    final layout; they may still be in the raw checkpoint layout.
    Overrides should answer ``True`` only when the source identity was
    verified and the incoming bytes can safely skip the module
    ``transform_weights()`` hooks.

    Returns:
        ``False`` in this default implementation.
    """
    return False
85+
7286
def post_load_apply(self,
7387
model: nn.Module,
7488
*,

tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py

Lines changed: 58 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -129,6 +129,8 @@ def __init__(
129129
self._model_name = str(model_name) if model_name is not None else None
130130
self._query_timeout_s = query_timeout_s
131131
self._p2p_succeeded = False
132+
self._post_transform_weights_preloaded = False
133+
self._source_identity_compatible_for_last_load = False
132134
# Receiver's local SourceIdentity, supplied per load_weights() call by
133135
# ModelLoader; the authority for the pre-transfer compatibility gate.
134136
self._local_source_identity: Optional[SourceIdentity] = None
@@ -184,6 +186,21 @@ def is_weights_preloaded(self) -> bool:
184186
"""
185187
return self._p2p_succeeded
186188

189+
def is_post_transform_weights_preloaded(self) -> bool:
    """Report whether the last successful MX preload delivered transformed bytes.

    Wave 4 wires the receiver-side staged hook path, while the MX
    publisher still emits raw pre-transform bytes, so this is meant to
    stay false until Wave 5 wires explicit MX metadata for post-transform
    publication. The source-identity result is folded in so callers have
    a single conservative signal: no identity match, no transform skip.

    Returns:
        A truthy value only when the P2P transfer succeeded, the source
        was flagged as post-transform, and the source identity was
        compatible for the last load.
    """
    # Same left-to-right short-circuit order as a single three-way ``and``.
    transformed_bytes_delivered = (
        self._p2p_succeeded and self._post_transform_weights_preloaded
    )
    return (
        transformed_bytes_delivered
        and self._source_identity_compatible_for_last_load
    )
203+
187204
def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]:
188205
"""Load weights, preferring MX P2P transfer when available.
189206
@@ -207,6 +224,8 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[
207224
# Popped here so it never leaks into the disk-fallback signature.
208225
self._local_source_identity = kwargs.pop("source_identity", None)
209226
self._p2p_succeeded = False
227+
self._post_transform_weights_preloaded = False
228+
self._source_identity_compatible_for_last_load = False
210229

211230
if self._mx_server_url is None or model is None:
212231
return self._fallback_to_disk(
@@ -237,14 +256,21 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[
237256

238257
# Pre-transfer compatibility gate: on mismatch, skip the transfer
239258
# before any RDMA work starts and fall back to disk.
240-
if not self._source_identity_compatible(checkpoint_dir, MxClient, _build_trtllm_identity):
259+
self._source_identity_compatible_for_last_load = self._source_identity_compatible(
260+
checkpoint_dir, MxClient, _build_trtllm_identity
261+
)
262+
if not self._source_identity_compatible_for_last_load:
241263
return self._fallback_to_disk(
242264
checkpoint_dir,
243265
mapping,
244266
reason="source SourceIdentity incompatible with receiver",
245267
**kwargs,
246268
)
247269

270+
self._post_transform_weights_preloaded = self._source_metadata_is_post_transform(
271+
checkpoint_dir, MxClient, _build_trtllm_identity
272+
)
273+
248274
timeout_override = self._resolve_query_timeout_override(
249275
checkpoint_dir,
250276
MxClient,
@@ -271,6 +297,24 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[
271297
fallback_bytes = sum(
272298
tensor.numel() * tensor.element_size() for tensor in fallback_weights.values()
273299
)
300+
if self._post_transform_weights_preloaded:
301+
self._post_transform_weights_preloaded = False
302+
self._source_identity_compatible_for_last_load = False
303+
logger.warning(
304+
"MX P2P returned %d fallback weights (%.2f MiB, size mismatch) "
305+
"from a post-transform source at %s. Falling back to a full "
306+
"disk load to avoid mixing transformed P2P tensors with raw "
307+
"fallback tensors before the full post-load transform path.",
308+
len(fallback_weights),
309+
fallback_bytes / (1 << 20),
310+
self._mx_server_url,
311+
)
312+
return self._fallback_to_disk(
313+
checkpoint_dir,
314+
mapping,
315+
reason="post-transform source returned partial fallback weights",
316+
**kwargs,
317+
)
274318
# Mixed-success case: MX delivered matched tensors into model
275319
# params via P2P and returned only size-mismatched tensors for
276320
# the standard disk path to apply. Keep the P2P transfer and
@@ -285,6 +329,7 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[
285329
self._mx_server_url,
286330
)
287331
self._p2p_succeeded = True
332+
self._post_transform_weights_preloaded = False
288333
return fallback_weights
289334

290335
self._p2p_succeeded = True
@@ -386,6 +431,18 @@ def _fetch_source_identity(
386431
# exposes a field for it. This is the single seam the gate depends on.
387432
return None
388433

434+
def _source_metadata_is_post_transform(
435+
self, _checkpoint_dir: str, _mx_client_type: Type[Any], _build_identity: Callable[..., Any]
436+
) -> bool:
437+
"""Whether the selected MX source publishes post-transform bytes.
438+
439+
Wave 4 keeps production behavior dormant: Modelexpress metadata does
440+
not yet expose raw-vs-transformed layout state, and the TRT-LLM MX
441+
publisher still publishes before module transforms. Wave 5 should wire
442+
this seam to explicit source metadata when flipping the publisher.
443+
"""
444+
return False
445+
389446
def _resolve_publish_name(self, checkpoint_dir: Optional[str]) -> str:
390447
return _resolve_mx_model_name(self._model_name, checkpoint_dir)
391448

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 76 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,6 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
14
import copy
25
import inspect
36
import os
@@ -264,6 +267,8 @@ class ModelLoader:
264267
Handles the loading, configuration, and weight initialization of a PyTorch model.
265268
This class isolates model loading logic from the main execution engine.
266269
"""
270+
_MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION = 1
271+
_MX_STAGED_RECEIVER_ALLOWLIST = frozenset()
267272

268273
def __init__(self,
269274
llm_args: TorchLlmArgs,
@@ -782,12 +787,20 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
782787
if not gms_post_load_handled:
783788
checkpoint_loader.post_load_apply(
784789
model, weights_preloaded=weights_preloaded)
790+
mx_staged_receiver_path = self._should_run_mx_staged_receiver_path(
791+
checkpoint_loader,
792+
model,
793+
weights_preloaded=weights_preloaded)
794+
if mx_staged_receiver_path:
795+
self._setup_aliases(model)
796+
self._mark_weights_transformed(model)
797+
self._walk_cache_state(model)
785798
checkpoint_loader.post_load_publish(
786799
model,
787800
checkpoint_dir=checkpoint_dir,
788801
weights_preloaded=weights_preloaded)
789-
790-
self._walk_full_post_load(model)
802+
if not mx_staged_receiver_path:
803+
self._walk_full_post_load(model)
791804

792805
# TODO(GMS-MOE-LB): when the (MoE, GMS) combination is enabled,
793806
# `register_weight_slots_after_to_cuda` and `finalize_model`
@@ -830,6 +843,67 @@ def _check_gms_source_identity(self, gms_backend) -> None:
830843
IdentityCheckPolicy.STRICT,
831844
)
832845

846+
@classmethod
def _should_run_mx_staged_receiver_path(
        cls, checkpoint_loader: BaseCheckpointLoader,
        model: DecoderModelForCausalLM, *, weights_preloaded: bool) -> bool:
    """Decide whether an MX receiver may skip the one-shot weight transforms.

    The Wave 4 path is intentionally dormant for production: the
    allow-list is empty and MX still reports raw pre-transform bytes.
    Tests can opt a synthetic model in by patching the allow-list and the
    checkpoint-loader signal, which exercises the staged receiver branch
    without enabling real models.

    Args:
        checkpoint_loader: Loader whose format and post-transform signal
            are consulted.
        model: Model whose concrete type is looked up in the allow-list.
        weights_preloaded: Whether the loader wrote weights directly into
            the model.

    Returns:
        ``True`` only for an MX loader that preloaded weights, reports
        post-transform bytes, and whose ``(model type, protocol version)``
        pair is allow-listed; ``False`` otherwise.
    """
    # Only an MX loader that actually preloaded weights is a candidate.
    is_mx_preload = (checkpoint_loader.checkpoint_format == "MX"
                     and weights_preloaded)
    if not is_mx_preload:
        return False

    # The hook is looked up on the loader's type and then called on the
    # instance; a missing hook or a falsy answer both decline.
    loader_cls = type(checkpoint_loader)
    if getattr(loader_cls, 'is_post_transform_weights_preloaded',
               None) is None:
        return False
    if not checkpoint_loader.is_post_transform_weights_preloaded():
        return False

    protocol_version = cls._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION
    model_cls = type(model)
    allow_listed = ((model_cls, protocol_version)
                    in cls._MX_STAGED_RECEIVER_ALLOWLIST)

    if not allow_listed:
        # WAVE 5 NOTE: once MX can publish real post-transform bytes, this
        # branch must not run the full post_load_weights() path on those
        # bytes for a non-allow-listed model. Wave 5 should either
        # fail/fallback before accepting the transfer or allow-list the
        # model after validation.
        logger.info(
            "MX receiver got post-transform weights for %s, but the model is "
            "not allow-listed for staged post-load transform protocol v%d. "
            "Running the full post-load path.",
            model_cls.__name__,
            protocol_version,
        )
        return False

    logger.info(
        "MX receiver using staged post-load path for %s "
        "(transform protocol v%d).",
        model_cls.__name__,
        protocol_version,
    )
    return True
891+
892+
@staticmethod
def _mark_weights_transformed(model: DecoderModelForCausalLM) -> None:
    """Flag transform-guarded modules as already transformed.

    Post-transform sharing paths skip ``transform_weights()`` because the
    incoming bytes already use the final runtime layout. Recording that
    lifecycle state on modules that take part in the transform-guard
    protocol keeps a later orchestrator/refactor from treating them as raw
    checkpoint bytes.

    Args:
        model: Model whose ``modules()`` are visited. Modules without a
            ``_weights_transformed`` attribute, or with a truthy
            ``_weights_removed`` attribute, are left untouched.
    """
    for submodule in model.modules():
        # Only modules that already carry the guard attribute participate.
        if not hasattr(submodule, '_weights_transformed'):
            continue
        # Modules flagged as having their weights removed are skipped.
        if getattr(submodule, '_weights_removed', False):
            continue
        submodule._weights_transformed = True
906+
833907
@staticmethod
834908
def _setup_aliases(model: DecoderModelForCausalLM) -> None:
835909
"""Run structural alias setup on eligible modules.

tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py

Lines changed: 50 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -44,11 +44,13 @@ def test_no_args_constructs(self):
4444
loader = MXCheckpointLoader()
4545
assert loader.mx_server_url is None
4646
assert loader.is_weights_preloaded() is False
47+
assert loader.is_post_transform_weights_preloaded() is False
4748

4849
def test_mx_server_url_stored(self):
4950
loader = MXCheckpointLoader(mx_server_url="http://mx:8001")
5051
assert loader.mx_server_url == "http://mx:8001"
5152
assert loader.is_weights_preloaded() is False
53+
assert loader.is_post_transform_weights_preloaded() is False
5254

5355
def test_query_timeout_stored(self):
5456
loader = MXCheckpointLoader(mx_server_url="http://mx:8001", query_timeout_s=900)
@@ -75,6 +77,17 @@ def test_checkpoint_format_backing_attr(self):
7577
def test_is_weights_preloaded_initial(self):
7678
loader = MXCheckpointLoader()
7779
assert loader.is_weights_preloaded() is False
80+
assert loader.is_post_transform_weights_preloaded() is False
81+
82+
def test_post_transform_signal_requires_p2p_and_identity_match(self):
    """With P2P and post-transform flags set, the signal follows the identity bit."""
    mx_loader = MXCheckpointLoader()
    mx_loader._p2p_succeeded = True
    mx_loader._post_transform_weights_preloaded = True

    # Identity mismatch first, then match: the signal must mirror it.
    for identity_matches in (False, True):
        mx_loader._source_identity_compatible_for_last_load = identity_matches
        assert mx_loader.is_post_transform_weights_preloaded() is identity_matches
7891

7992

8093
# ---------------------------------------------------------------------------
@@ -145,9 +158,11 @@ def _modelexpress_unavailable(stack):
145158

146159
@staticmethod
147160
def _upstream_raises(stack):
161+
loader = MXCheckpointLoader(mx_server_url="http://mx:8001")
162+
loader._source_identity_compatible = MagicMock(return_value=True)
148163
fake_mx = _build_fake_modelexpress(load_weights_side_effect=RuntimeError("boom"))
149164
stack.enter_context(_install_fake_modelexpress(fake_mx))
150-
return (MXCheckpointLoader(mx_server_url="http://mx:8001"), {"model": MagicMock()})
165+
return (loader, {"model": MagicMock()})
151166

152167
@pytest.mark.parametrize(
153168
"trigger_id, setup",
@@ -196,6 +211,7 @@ def test_p2p_full_success_returns_empty_dict(self):
196211
# ``is_weights_preloaded()`` signal as "skip the standard
197212
# weight-mapping pipeline".
198213
loader = MXCheckpointLoader(mx_server_url="http://mx:8001")
214+
loader._source_identity_compatible = MagicMock(return_value=True)
199215
fake_mx = _build_fake_modelexpress(load_weights_return={})
200216
mapping = MagicMock(name="mapping")
201217
model = MagicMock(name="model")
@@ -205,6 +221,7 @@ def test_p2p_full_success_returns_empty_dict(self):
205221

206222
assert result == {}
207223
assert loader.is_weights_preloaded() is True
224+
assert loader.is_post_transform_weights_preloaded() is False
208225

209226
# Verify the integration contract with the upstream library:
210227
# 1. Constructed MxLiveWeightLoader with our mx_server_url.
@@ -222,6 +239,7 @@ def test_mixed_success_returns_fallback_weights(self):
222239
# tensors), keep the P2P transfer and let ModelLoader merge these
223240
# tensors through the standard disk pipeline.
224241
loader = MXCheckpointLoader(mx_server_url="http://mx:8001")
242+
loader._source_identity_compatible = MagicMock(return_value=True)
225243
fallback = {"some.weight": MagicMock()}
226244
fake_mx = _build_fake_modelexpress(load_weights_return=fallback)
227245

@@ -232,9 +250,36 @@ def test_mixed_success_returns_fallback_weights(self):
232250
result = loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock())
233251

234252
assert loader.is_weights_preloaded() is True
253+
assert loader.is_post_transform_weights_preloaded() is False
235254
assert result is fallback
236255
mock_super_load.assert_not_called()
237256

257+
def test_post_transform_mixed_success_falls_back_to_full_disk_load(self):
    """A partial transfer from a post-transform source must become a full disk load.

    Wave 5 will let MX advertise post-transform sources. If such a source
    only partially succeeds, merging the raw fallback tensors would force
    ModelLoader onto the full post-load path and double-transform the P2P
    subset. This locks in the safer behavior: abandon the partial
    post-transform transfer and return a full disk load instead.
    """
    mx_loader = MXCheckpointLoader(mx_server_url="http://mx:8001")
    # Force both the identity gate and the post-transform seam to pass.
    mx_loader._source_identity_compatible = MagicMock(return_value=True)
    mx_loader._source_metadata_is_post_transform = MagicMock(return_value=True)

    # A single leftover tensor with a concrete size for the byte accounting.
    leftover_tensor = MagicMock(numel=lambda: 1, element_size=lambda: 4)
    partial_weights = {"some.weight": leftover_tensor}
    weights_from_disk = {"disk.weight": MagicMock()}
    fake_mx = _build_fake_modelexpress(load_weights_return=partial_weights)

    with (
        _install_fake_modelexpress(fake_mx),
        patch.object(
            HfCheckpointLoader, "load_weights", return_value=weights_from_disk
        ) as hf_load,
    ):
        loaded = mx_loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock())

    assert loaded is weights_from_disk
    assert mx_loader.is_weights_preloaded() is False
    assert mx_loader.is_post_transform_weights_preloaded() is False
    hf_load.assert_called_once()
282+
238283

239284
# ---------------------------------------------------------------------------
240285
# publish_as_source — env-var dance and graceful no-op
@@ -428,6 +473,7 @@ def _assert_timeout(*args, **kwargs):
428473
return {}
429474

430475
loader = MXCheckpointLoader(mx_server_url="http://mx:8001")
476+
loader._source_identity_compatible = MagicMock(return_value=True)
431477
fake_mx = _build_fake_modelexpress(load_weights_side_effect=_assert_timeout)
432478
with _install_fake_modelexpress(fake_mx):
433479
loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock())
@@ -439,6 +485,7 @@ def _assert_no_timeout(*args, **kwargs):
439485
return {}
440486

441487
loader = MXCheckpointLoader(mx_server_url="http://mx:8001")
488+
loader._source_identity_compatible = MagicMock(return_value=True)
442489
fake_mx = _build_fake_modelexpress(
443490
load_weights_side_effect=_assert_no_timeout,
444491
source_instances=[MagicMock()],
@@ -457,6 +504,7 @@ def _assert_env_timeout(*args, **kwargs):
457504
return {}
458505

459506
loader = MXCheckpointLoader(mx_server_url="http://mx:8001")
507+
loader._source_identity_compatible = MagicMock(return_value=True)
460508
fake_mx = _build_fake_modelexpress(load_weights_side_effect=_assert_env_timeout)
461509
with _install_fake_modelexpress(fake_mx):
462510
loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock())
@@ -468,6 +516,7 @@ def _assert_config_timeout(*args, **kwargs):
468516
return {}
469517

470518
loader = MXCheckpointLoader(mx_server_url="http://mx:8001", query_timeout_s=900)
519+
loader._source_identity_compatible = MagicMock(return_value=True)
471520
fake_mx = _build_fake_modelexpress(load_weights_side_effect=_assert_config_timeout)
472521
with _install_fake_modelexpress(fake_mx):
473522
loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock())

0 commit comments

Comments
 (0)