Skip to content

Commit ac30c0a

Browse files
committed
[TRTLLM-13246][feat] Wave 1: migrate aliases to setup_aliases and stage GMS RO load
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent d6cd918 commit ac30c0a

10 files changed

Lines changed: 122 additions & 36 deletions

File tree

tensorrt_llm/_torch/memory/gpu_memory_backend.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,14 @@
3636
CUDA memory pool. After loading, weights are committed for read-only
3737
access by other workers and the client transitions to RO mode in place.
3838
- **RO (Read-Only)**: Subsequent workers zero-copy import already-committed
39-
weights from the GMS pool. `post_load_weights()` must run BEFORE
40-
materialization so that module aliases are set up correctly.
39+
weights from the GMS pool. `setup_aliases()` must run BEFORE
40+
materialization so that module aliases are set up correctly, while derived
41+
state is refreshed after real tensors are bound. RO is validated for models
42+
whose `post_load_weights()` is pure alias wiring; models that additionally
43+
rely on plain Python attributes set inside `post_load_weights()` (rather
44+
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
45+
those side effects to `transform_weights()` or `cache_derived_state()`
46+
before they are safe on the RO path.
4147
"""
4248

4349
from contextlib import contextmanager
@@ -477,7 +483,7 @@ def materialize_module(self, model: nn.Module) -> None:
477483
by GPU pointers from the shared memory region — no data copies,
478484
no disk I/O, just CUDA VMM remapping. The model's submodule
479485
layout must already match the writer's at commit time, including
480-
any aliases / derived buffers introduced by `post_load_weights`.
486+
any aliases introduced by `setup_aliases`.
481487
482488
Args:
483489
model: The `nn.Module` to materialize. Walks the full
@@ -489,11 +495,10 @@ def materialize_module(self, model: nn.Module) -> None:
489495
RuntimeError: If `connect()` has not been called yet.
490496
491497
Note:
492-
`post_load_weights()` must be called on the model BEFORE
493-
this method. The order ensures that any aliases / derived
494-
parameters created by post-load hooks are present on the
495-
module tree at materialization time, so they are bound to
496-
the same GMS storage as their primary tensor.
498+
`setup_aliases()` must be called on the model BEFORE this method.
499+
The order ensures that any structural aliases created by post-load
500+
hooks are present on the module tree at materialization time, so
501+
they are bound to the same GMS storage as their primary tensor.
497502
"""
498503
if self._client is None:
499504
raise RuntimeError("GMS client not connected. Call connect() first.")

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1920,7 +1920,7 @@ def load_weights(self, weights: ConsumableWeightsDict):
19201920
weight_loader = DeepseekV3WeightLoader(self)
19211921
weight_loader.load_weights(weights)
19221922

1923-
def post_load_weights(self):
1923+
def setup_aliases(self):
19241924
for idx, layer in enumerate(
19251925
self.model.layers[:self.config.num_hidden_layers]):
19261926
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_exaone_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,7 +710,7 @@ def load_weights(
710710
allow_partial_loading=allow_partial_loading,
711711
)
712712

713-
def post_load_weights(self):
713+
def setup_aliases(self):
714714
# For the cross-layer residual+LN fusion.
715715
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
716716
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_glm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo
10741074
weight_loader = Glm4WeightLoader(self)
10751075
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
10761076

1077-
def post_load_weights(self):
1077+
def setup_aliases(self):
10781078
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
10791079
if idx == self.config.num_hidden_layers - 1:
10801080
layer.next_layer_layernorm = self.model.norm

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def load_weights(self, weights: Dict):
631631
else:
632632
self.load_hf_weights(weights)
633633

634-
def post_load_weights(self):
634+
def setup_aliases(self):
635635
for idx, layer in enumerate(
636636
self.model.block[:self.config.num_hidden_layers]):
637637
if idx == 0:

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def __init__(
484484
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
485485
eps=config.rms_norm_eps,
486486
dtype=config.torch_dtype)
487-
# When post_load_weights() chains layernorms across layers,
487+
# When setup_aliases() chains layernorms across layers,
488488
# this flag is set to True to skip the input layernorm in
489489
# forward() since it is handled by the previous layer.
490490
self.skip_input_layernorm = False
@@ -709,7 +709,7 @@ def __init__(
709709
quantize_type="nvfp4"
710710
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
711711
and not (differ_pp_stage_with_previous_layer) else None)
712-
# When post_load_weights() chains layernorms across layers,
712+
# When setup_aliases() chains layernorms across layers,
713713
# this flag is set to True to skip the input layernorm in
714714
# forward() since it is handled by the previous layer.
715715
self.skip_input_layernorm = False
@@ -983,7 +983,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
983983
self.norm = RMSNorm(hidden_size=config.hidden_size,
984984
eps=config.rms_norm_eps,
985985
dtype=config.torch_dtype)
986-
# When post_load_weights() chains the final norm into the
986+
# When setup_aliases() chains the final norm into the
987987
# last decoder layer, this flag is set to True to skip
988988
# applying it again in forward().
989989
self.skip_norm = False
@@ -1088,7 +1088,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
10881088
self.norm = RMSNorm(hidden_size=config.hidden_size,
10891089
eps=config.rms_norm_eps,
10901090
dtype=config.torch_dtype)
1091-
# When post_load_weights() chains the final norm into the
1091+
# When setup_aliases() chains the final norm into the
10921092
# last decoder layer, this flag is set to True to skip
10931093
# applying it again in forward().
10941094
self.skip_norm = False
@@ -1140,7 +1140,7 @@ def __init__(
11401140
):
11411141
super().__init__(LlamaModel(model_config), model_config)
11421142

1143-
def post_load_weights(self):
1143+
def setup_aliases(self):
11441144
for idx, layer in enumerate(
11451145
self.model.layers[:self.config.num_hidden_layers]):
11461146
if idx == self.config.num_hidden_layers - 1:
@@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
15641564
if had_mm_encoder:
15651565
self.mm_encoder = saved_mm_encoder
15661566

1567-
def post_load_weights(self):
1567+
def setup_aliases(self):
15681568
for idx, layer in enumerate(
15691569
self.model.layers[:self.config.num_hidden_layers]):
15701570
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def __init__(
417417
)
418418
self.preload_weight_modules = self.model.preload_weight_modules
419419

420-
def post_load_weights(self):
420+
def setup_aliases(self):
421421
for idx, layer in enumerate(
422422
self.model.layers[:self.config.num_hidden_layers]):
423423
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -977,7 +977,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
977977
new_weights = weight_mapper.preprocess_weights(weights)
978978
super().load_weights(new_weights, weight_mapper)
979979

980-
def post_load_weights(self):
980+
def setup_aliases(self):
981981
for idx, layer in enumerate(
982982
self.model.layers[:self.config.num_hidden_layers]):
983983
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,9 @@ def init_meta_tensor(t: torch.Tensor):
455455
# post_load_* hooks itself, so the shared post-load block below
456456
# must skip them. RW handles them inside `mem_pool_scope` so the
457457
# committed pool reflects the post-post_load layout; RO runs
458-
# `module.post_load_weights()` before `materialize_module` to
459-
# wire aliases prior to zero-copy mapping.
458+
# ``setup_aliases()`` before ``materialize_module`` to wire aliases
459+
# prior to zero-copy mapping, then refreshes derived state after
460+
# real GMS tensors are bound.
460461
gms_post_load_handled = False
461462
if load_format == LoadFormat.AUTO:
462463
# Pass model= so format-specific loaders (e.g. MX) can
@@ -685,31 +686,33 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
685686
# presharded modules).
686687
#
687688
# Hook order:
688-
# 1. `post_load_apply`: format-specific apply
689+
# 1. ``post_load_apply``: format-specific apply
689690
# work (e.g., MX preshard markers).
690-
# 2. Per-module `post_load_weights`: creates
691-
# aliases/derived parameter attributes BEFORE
692-
# `materialize_module` walks the final module
693-
# tree (including `draft_model` for spec dec).
694-
# 3. `materialize_module`: zero-copy bind GMS
691+
# 2. Top-level ``setup_aliases``: creates structural
692+
# aliases BEFORE ``materialize_module`` walks the
693+
# final module tree (including ``draft_model`` for
694+
# spec dec).
695+
# 3. SourceIdentity gate: STRICT pre-materialize
696+
# compatibility check (GMS has no disk fallback).
697+
# 4. ``materialize_module``: zero-copy bind GMS
695698
# pool storage onto the model parameters.
696-
# 4. `post_load_publish`: any receiver-side
699+
# 5. Per-module ``cache_derived_state``: recompute
700+
# Python-side state from real, materialized
701+
# tensors without re-running one-shot transforms.
702+
# 6. ``post_load_publish``: any receiver-side
697703
# publish (no-op via the receiver guard).
698704
checkpoint_loader.post_load_apply(
699705
model, weights_preloaded=True)
700706

701-
for module in model.modules():
702-
if hasattr(module,
703-
'post_load_weights') and not getattr(
704-
module, '_weights_removed', False):
705-
module.post_load_weights()
707+
self._setup_aliases(model)
706708

707709
# Pre-materialize compatibility gate. GMS has no
708710
# disk-fallback path, so a mismatch raises under STRICT
709711
# rather than falling back.
710712
self._check_gms_source_identity(gms_backend)
711713

712714
gms_backend.materialize_module(model)
715+
self._walk_cache_state(model)
713716

714717
checkpoint_loader.post_load_publish(
715718
model,

tests/unittest/_torch/pyexecutor/test_model_loader_mx.py

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ def load_weights(self, weights, mapper):
4848
def load_draft_weights(self, weights, mapper):
4949
self._events.append("load_draft_weights")
5050

51+
def setup_aliases(self):
52+
self._events.append("setup_aliases")
53+
54+
def cache_derived_state(self):
55+
self._events.append("cache_derived_state")
56+
5157
def post_load_weights(self):
5258
self._events.append("post_load_weights")
5359

@@ -57,8 +63,15 @@ def _moe_context(config, mapping):
5763
yield None
5864

5965

60-
def _make_loader(monkeypatch, *, events, spec_config=None):
61-
llm_args = SimpleNamespace(load_format=LoadFormat.AUTO)
66+
def _make_loader(monkeypatch, *, events, spec_config=None, load_format=LoadFormat.AUTO):
67+
llm_args = SimpleNamespace(
68+
load_format=load_format,
69+
gms_config=SimpleNamespace(
70+
socket_path="/tmp/gms.sock",
71+
mode="test",
72+
tag="test",
73+
),
74+
)
6275
loader = ModelLoader(
6376
llm_args=llm_args,
6477
mapping=MagicMock(name="mapping"),
@@ -75,6 +88,15 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7588
monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext())
7689
monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context)
7790
monkeypatch.setattr(model_loader_mod, "MetaInitMode", lambda: nullcontext())
91+
# The MX and GMS load paths build a receiver-side SourceIdentity from the
92+
# resolved ModelConfig. These tests stub the config, so short-circuit the
93+
# fingerprint construction to a sentinel; identity-comparison logic is
94+
# covered separately in test_source_identity.py.
95+
monkeypatch.setattr(
96+
model_loader_mod.SourceIdentity,
97+
"from_model_config",
98+
classmethod(lambda cls, *_args, **_kwargs: SimpleNamespace(name="local-identity")),
99+
)
78100
monkeypatch.setattr(
79101
model_loader_mod.AutoModelForCausalLM,
80102
"from_config",
@@ -157,6 +179,62 @@ def test_mx_fallback_runs_standard_weight_mapping(monkeypatch):
157179
)
158180

159181

182+
def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
183+
events = []
184+
loader = _make_loader(monkeypatch, events=events, load_format=LoadFormat.GMS)
185+
checkpoint_loader = MagicMock(name="checkpoint_loader")
186+
checkpoint_loader.checkpoint_format = "GMS"
187+
188+
def record(event):
189+
def _append(*_args, **_kwargs):
190+
events.append(event)
191+
192+
return _append
193+
194+
checkpoint_loader.post_load_apply.side_effect = record("post_load_apply")
195+
checkpoint_loader.post_load_publish.side_effect = record("post_load_publish")
196+
197+
# The STRICT pre-materialize identity gate runs between alias setup and
198+
# materialization; record it to pin the ordering, without exercising the
199+
# comparison logic (covered in test_source_identity.py).
200+
monkeypatch.setattr(
201+
model_loader_mod,
202+
"check_weight_sharing_compatibility",
203+
lambda *_args, **_kwargs: events.append("check_source_identity"),
204+
)
205+
206+
class _GmsBackend:
207+
def __init__(self, *args, **kwargs):
208+
self.is_rw = False
209+
210+
def connect(self):
211+
return True
212+
213+
def get_source_identity(self):
214+
return SimpleNamespace(name="remote-identity")
215+
216+
def materialize_module(self, model):
217+
events.append("materialize_module")
218+
219+
def cleanup(self):
220+
events.append("cleanup")
221+
222+
monkeypatch.setattr("tensorrt_llm._torch.memory.GMSBackend", _GmsBackend)
223+
224+
loader.load("/ckpt", checkpoint_loader)
225+
226+
assert events == [
227+
"post_load_apply",
228+
"setup_aliases",
229+
"check_source_identity",
230+
"materialize_module",
231+
"cache_derived_state",
232+
"post_load_publish",
233+
]
234+
assert "post_load_weights" not in events
235+
checkpoint_loader.load_weights.assert_not_called()
236+
237+
160238
class _HookRecorder(nn.Module):
161239
def __init__(
162240
self,

0 commit comments

Comments
 (0)