Skip to content

Commit eabb7c0

Browse files
committed
[TRTLLM-13246][feat] Wave 1: migrate aliases to setup_aliases and stage GMS RO load
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent b92438a commit eabb7c0

10 files changed

Lines changed: 127 additions & 36 deletions

File tree

tensorrt_llm/_torch/memory/gpu_memory_backend.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,14 @@
3636
CUDA memory pool. After loading, weights are committed for read-only
3737
access by other workers and the client transitions to RO mode in place.
3838
- **RO (Read-Only)**: Subsequent workers zero-copy import already-committed
39-
weights from the GMS pool. `post_load_weights()` must run BEFORE
40-
materialization so that module aliases are set up correctly.
39+
weights from the GMS pool. `setup_aliases()` must run BEFORE
40+
materialization so that module aliases are set up correctly, while derived
41+
state is refreshed after real tensors are bound. RO is validated for models
42+
whose `post_load_weights()` is pure alias wiring; models that additionally
43+
rely on plain Python attributes set inside `post_load_weights()` (rather
44+
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
45+
those side effects to `transform_weights()` or `cache_derived_state()`
46+
before they are safe on the RO path.
4147
"""
4248

4349
from contextlib import contextmanager
@@ -477,7 +483,7 @@ def materialize_module(self, model: nn.Module) -> None:
477483
by GPU pointers from the shared memory region — no data copies,
478484
no disk I/O, just CUDA VMM remapping. The model's submodule
479485
layout must already match the writer's at commit time, including
480-
any aliases / derived buffers introduced by `post_load_weights`.
486+
any aliases introduced by `setup_aliases`.
481487
482488
Args:
483489
model: The `nn.Module` to materialize. Walks the full
@@ -489,11 +495,10 @@ def materialize_module(self, model: nn.Module) -> None:
489495
RuntimeError: If `connect()` has not been called yet.
490496
491497
Note:
492-
`post_load_weights()` must be called on the model BEFORE
493-
this method. The order ensures that any aliases / derived
494-
parameters created by post-load hooks are present on the
495-
module tree at materialization time, so they are bound to
496-
the same GMS storage as their primary tensor.
498+
`setup_aliases()` must be called on the model BEFORE this method.
499+
The order ensures that any structural aliases created by post-load
500+
hooks are present on the module tree at materialization time, so
501+
they are bound to the same GMS storage as their primary tensor.
497502
"""
498503
if self._client is None:
499504
raise RuntimeError("GMS client not connected. Call connect() first.")

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,7 @@ def load_weights(self, weights: ConsumableWeightsDict):
19201920
weight_loader = DeepseekV3WeightLoader(self)
19211921
weight_loader.load_weights(weights)
19221922

1923-
def post_load_weights(self):
1923+
def setup_aliases(self):
19241924
for idx, layer in enumerate(
19251925
self.model.layers[:self.config.num_hidden_layers]):
19261926
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_exaone_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ def load_weights(
725725
allow_partial_loading=allow_partial_loading,
726726
)
727727

728-
def post_load_weights(self):
728+
def setup_aliases(self):
729729
# For the cross-layer residual+LN fusion.
730730
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
731731
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_glm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo
10741074
weight_loader = Glm4WeightLoader(self)
10751075
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
10761076

1077-
def post_load_weights(self):
1077+
def setup_aliases(self):
10781078
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
10791079
if idx == self.config.num_hidden_layers - 1:
10801080
layer.next_layer_layernorm = self.model.norm

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def load_weights(self, weights: Dict):
631631
else:
632632
self.load_hf_weights(weights)
633633

634-
def post_load_weights(self):
634+
def setup_aliases(self):
635635
for idx, layer in enumerate(
636636
self.model.block[:self.config.num_hidden_layers]):
637637
if idx == 0:

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def __init__(
484484
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
485485
eps=config.rms_norm_eps,
486486
dtype=config.torch_dtype)
487-
# When post_load_weights() chains layernorms across layers,
487+
# When setup_aliases() chains layernorms across layers,
488488
# this flag is set to True to skip the input layernorm in
489489
# forward() since it is handled by the previous layer.
490490
self.skip_input_layernorm = False
@@ -709,7 +709,7 @@ def __init__(
709709
quantize_type="nvfp4"
710710
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
711711
and not (differ_pp_stage_with_previous_layer) else None)
712-
# When post_load_weights() chains layernorms across layers,
712+
# When setup_aliases() chains layernorms across layers,
713713
# this flag is set to True to skip the input layernorm in
714714
# forward() since it is handled by the previous layer.
715715
self.skip_input_layernorm = False
@@ -983,7 +983,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
983983
self.norm = RMSNorm(hidden_size=config.hidden_size,
984984
eps=config.rms_norm_eps,
985985
dtype=config.torch_dtype)
986-
# When post_load_weights() chains the final norm into the
986+
# When setup_aliases() chains the final norm into the
987987
# last decoder layer, this flag is set to True to skip
988988
# applying it again in forward().
989989
self.skip_norm = False
@@ -1088,7 +1088,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
10881088
self.norm = RMSNorm(hidden_size=config.hidden_size,
10891089
eps=config.rms_norm_eps,
10901090
dtype=config.torch_dtype)
1091-
# When post_load_weights() chains the final norm into the
1091+
# When setup_aliases() chains the final norm into the
10921092
# last decoder layer, this flag is set to True to skip
10931093
# applying it again in forward().
10941094
self.skip_norm = False
@@ -1140,7 +1140,7 @@ def __init__(
11401140
):
11411141
super().__init__(LlamaModel(model_config), model_config)
11421142

1143-
def post_load_weights(self):
1143+
def setup_aliases(self):
11441144
for idx, layer in enumerate(
11451145
self.model.layers[:self.config.num_hidden_layers]):
11461146
if idx == self.config.num_hidden_layers - 1:
@@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
15641564
if had_mm_encoder:
15651565
self.mm_encoder = saved_mm_encoder
15661566

1567-
def post_load_weights(self):
1567+
def setup_aliases(self):
15681568
for idx, layer in enumerate(
15691569
self.model.layers[:self.config.num_hidden_layers]):
15701570
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def __init__(
417417
)
418418
self.preload_weight_modules = self.model.preload_weight_modules
419419

420-
def post_load_weights(self):
420+
def setup_aliases(self):
421421
for idx, layer in enumerate(
422422
self.model.layers[:self.config.num_hidden_layers]):
423423
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
980980
new_weights = weight_mapper.preprocess_weights(weights)
981981
super().load_weights(new_weights, weight_mapper)
982982

983-
def post_load_weights(self):
983+
def setup_aliases(self):
984984
for idx, layer in enumerate(
985985
self.model.layers[:self.config.num_hidden_layers]):
986986
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,9 @@ def init_meta_tensor(t: torch.Tensor):
462462
# post_load_* hooks itself, so the shared post-load block below
463463
# must skip them. RW handles them inside `mem_pool_scope` so the
464464
# committed pool reflects the post-post_load layout; RO runs
465-
# `module.post_load_weights()` before `materialize_module` to
466-
# wire aliases prior to zero-copy mapping.
465+
# ``setup_aliases()`` before ``materialize_module`` to wire aliases
466+
# prior to zero-copy mapping, then refreshes derived state after
467+
# real GMS tensors are bound.
467468
gms_post_load_handled = False
468469
if load_format == LoadFormat.AUTO:
469470
# Pass model= so format-specific loaders (e.g. MX) can
@@ -692,31 +693,33 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
692693
# presharded modules).
693694
#
694695
# Hook order:
695-
# 1. `post_load_apply`: format-specific apply
696+
# 1. ``post_load_apply``: format-specific apply
696697
# work (e.g., MX preshard markers).
697-
# 2. Per-module `post_load_weights`: creates
698-
# aliases/derived parameter attributes BEFORE
699-
# `materialize_module` walks the final module
700-
# tree (including `draft_model` for spec dec).
701-
# 3. `materialize_module`: zero-copy bind GMS
698+
# 2. Top-level ``setup_aliases``: creates structural
699+
# aliases BEFORE ``materialize_module`` walks the
700+
# final module tree (including ``draft_model`` for
701+
# spec dec).
702+
# 3. SourceIdentity gate: STRICT pre-materialize
703+
# compatibility check (GMS has no disk fallback).
704+
# 4. ``materialize_module``: zero-copy bind GMS
702705
# pool storage onto the model parameters.
703-
# 4. `post_load_publish`: any receiver-side
706+
# 5. Per-module ``cache_derived_state``: recompute
707+
# Python-side state from real, materialized
708+
# tensors without re-running one-shot transforms.
709+
# 6. ``post_load_publish``: any receiver-side
704710
# publish (no-op via the receiver guard).
705711
checkpoint_loader.post_load_apply(
706712
model, weights_preloaded=True)
707713

708-
for module in model.modules():
709-
if hasattr(module,
710-
'post_load_weights') and not getattr(
711-
module, '_weights_removed', False):
712-
module.post_load_weights()
714+
self._setup_aliases(model)
713715

714716
# Pre-materialize compatibility gate. GMS has no
715717
# disk-fallback path, so a mismatch raises under STRICT
716718
# rather than falling back.
717719
self._check_gms_source_identity(gms_backend)
718720

719721
gms_backend.materialize_module(model)
722+
self._walk_cache_state(model)
720723

721724
checkpoint_loader.post_load_publish(
722725
model,
@@ -929,6 +932,7 @@ def reload(self,
929932
"Cannot reload weights: weight_mapper was not initialized. "
930933
"This can happen when the initial load used GMS, MX P2P, or "
931934
"VISION_ONLY, which bypass the standard weight mapping path.")
935+
self._reset_weights_transformed(model)
932936
self._call_load_weights(model.load_weights,
933937
weights,
934938
self.weight_mapper,

tests/unittest/_torch/pyexecutor/test_model_loader_mx.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ def load_weights(self, weights, mapper):
4848
def load_draft_weights(self, weights, mapper):
4949
self._events.append("load_draft_weights")
5050

51+
def setup_aliases(self):
52+
self._events.append("setup_aliases")
53+
54+
def cache_derived_state(self):
55+
self._events.append("cache_derived_state")
56+
5157
def post_load_weights(self):
5258
self._events.append("post_load_weights")
5359

@@ -57,8 +63,15 @@ def _moe_context(config, mapping):
5763
yield None
5864

5965

60-
def _make_loader(monkeypatch, *, events, spec_config=None):
61-
llm_args = SimpleNamespace(load_format=LoadFormat.AUTO)
66+
def _make_loader(monkeypatch, *, events, spec_config=None, load_format=LoadFormat.AUTO):
67+
llm_args = SimpleNamespace(
68+
load_format=load_format,
69+
gms_config=SimpleNamespace(
70+
socket_path="/tmp/gms.sock",
71+
mode="test",
72+
tag="test",
73+
),
74+
)
6275
loader = ModelLoader(
6376
llm_args=llm_args,
6477
mapping=MagicMock(name="mapping"),
@@ -75,6 +88,15 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7588
monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext())
7689
monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context)
7790
monkeypatch.setattr(model_loader_mod, "MetaInitMode", lambda: nullcontext())
91+
# The MX and GMS load paths build a receiver-side SourceIdentity from the
92+
# resolved ModelConfig. These tests stub the config, so short-circuit the
93+
# fingerprint construction to a sentinel; identity-comparison logic is
94+
# covered separately in test_source_identity.py.
95+
monkeypatch.setattr(
96+
model_loader_mod.SourceIdentity,
97+
"from_model_config",
98+
classmethod(lambda cls, *_args, **_kwargs: SimpleNamespace(name="local-identity")),
99+
)
78100
monkeypatch.setattr(
79101
model_loader_mod.AutoModelForCausalLM,
80102
"from_config",
@@ -112,8 +134,12 @@ def test_mx_success_initializes_mapper_skips_weight_mapping_and_reload_works(mon
112134

113135
# reload() uses self.weight_mapper unconditionally; MX success must
114136
# initialize it even though the initial load skipped _call_load_weights.
137+
model._weights_transformed = True
138+
model.linear._weights_transformed = True
115139
loader.reload(model, {"reloaded": MagicMock()})
116140
assert loader._call_load_weights.call_count == 1
141+
assert model._weights_transformed is False
142+
assert model.linear._weights_transformed is False
117143
assert events == ["post_load_weights", "load_weights"]
118144

119145

@@ -157,6 +183,62 @@ def test_mx_fallback_runs_standard_weight_mapping(monkeypatch):
157183
)
158184

159185

186+
def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
187+
events = []
188+
loader = _make_loader(monkeypatch, events=events, load_format=LoadFormat.GMS)
189+
checkpoint_loader = MagicMock(name="checkpoint_loader")
190+
checkpoint_loader.checkpoint_format = "GMS"
191+
192+
def record(event):
193+
def _append(*_args, **_kwargs):
194+
events.append(event)
195+
196+
return _append
197+
198+
checkpoint_loader.post_load_apply.side_effect = record("post_load_apply")
199+
checkpoint_loader.post_load_publish.side_effect = record("post_load_publish")
200+
201+
# The STRICT pre-materialize identity gate runs between alias setup and
202+
# materialization; record it to pin the ordering, without exercising the
203+
# comparison logic (covered in test_source_identity.py).
204+
monkeypatch.setattr(
205+
model_loader_mod,
206+
"check_weight_sharing_compatibility",
207+
lambda *_args, **_kwargs: events.append("check_source_identity"),
208+
)
209+
210+
class _GmsBackend:
211+
def __init__(self, *args, **kwargs):
212+
self.is_rw = False
213+
214+
def connect(self):
215+
return True
216+
217+
def get_source_identity(self):
218+
return SimpleNamespace(name="remote-identity")
219+
220+
def materialize_module(self, model):
221+
events.append("materialize_module")
222+
223+
def cleanup(self):
224+
events.append("cleanup")
225+
226+
monkeypatch.setattr("tensorrt_llm._torch.memory.GMSBackend", _GmsBackend)
227+
228+
loader.load("/ckpt", checkpoint_loader)
229+
230+
assert events == [
231+
"post_load_apply",
232+
"setup_aliases",
233+
"check_source_identity",
234+
"materialize_module",
235+
"cache_derived_state",
236+
"post_load_publish",
237+
]
238+
assert "post_load_weights" not in events
239+
checkpoint_loader.load_weights.assert_not_called()
240+
241+
160242
class _HookRecorder(nn.Module):
161243
def __init__(
162244
self,

0 commit comments

Comments
 (0)