Skip to content

Commit 4352612

Browse files
committed
[TRTLLM-13246][feat] Wave 1: migrate aliases to setup_aliases and stage GMS RO load
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent aef7d47 commit 4352612

11 files changed

Lines changed: 139 additions & 51 deletions

File tree

tensorrt_llm/_torch/memory/gpu_memory_backend.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,14 @@
3636
CUDA memory pool. After loading, weights are committed for read-only
3737
access by other workers and the client transitions to RO mode in place.
3838
- **RO (Read-Only)**: Subsequent workers zero-copy import already-committed
39-
weights from the GMS pool. `post_load_weights()` must run BEFORE
40-
materialization so that module aliases are set up correctly.
39+
weights from the GMS pool. `setup_aliases()` must run BEFORE
40+
materialization so that module aliases are set up correctly, while derived
41+
state is refreshed after real tensors are bound. RO is validated for models
42+
whose `post_load_weights()` is pure alias wiring; models that additionally
43+
rely on plain Python attributes set inside `post_load_weights()` (rather
44+
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
45+
those side effects to `transform_weights()` or `cache_derived_state()`
46+
before they are safe on the RO path.
4147
"""
4248

4349
from contextlib import contextmanager
@@ -477,7 +483,7 @@ def materialize_module(self, model: nn.Module) -> None:
477483
by GPU pointers from the shared memory region — no data copies,
478484
no disk I/O, just CUDA VMM remapping. The model's submodule
479485
layout must already match the writer's at commit time, including
480-
any aliases / derived buffers introduced by `post_load_weights`.
486+
any aliases introduced by `setup_aliases`.
481487
482488
Args:
483489
model: The `nn.Module` to materialize. Walks the full
@@ -489,11 +495,10 @@ def materialize_module(self, model: nn.Module) -> None:
489495
RuntimeError: If `connect()` has not been called yet.
490496
491497
Note:
492-
`post_load_weights()` must be called on the model BEFORE
493-
this method. The order ensures that any aliases / derived
494-
parameters created by post-load hooks are present on the
495-
module tree at materialization time, so they are bound to
496-
the same GMS storage as their primary tensor.
498+
`setup_aliases()` must be called on the model BEFORE this method.
499+
The order ensures that any structural aliases created by post-load
500+
hooks are present on the module tree at materialization time, so
501+
they are bound to the same GMS storage as their primary tensor.
497502
"""
498503
if self._client is None:
499504
raise RuntimeError("GMS client not connected. Call connect() first.")

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1921,7 +1921,7 @@ def load_weights(self, weights: ConsumableWeightsDict):
19211921
weight_loader = DeepseekV3WeightLoader(self)
19221922
weight_loader.load_weights(weights)
19231923

1924-
def post_load_weights(self):
1924+
def setup_aliases(self):
19251925
for idx, layer in enumerate(
19261926
self.model.layers[:self.config.num_hidden_layers]):
19271927
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_exaone_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ def load_weights(
725725
allow_partial_loading=allow_partial_loading,
726726
)
727727

728-
def post_load_weights(self):
728+
def setup_aliases(self):
729729
# For the cross-layer residual+LN fusion.
730730
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
731731
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_glm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo
10741074
weight_loader = Glm4WeightLoader(self)
10751075
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
10761076

1077-
def post_load_weights(self):
1077+
def setup_aliases(self):
10781078
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
10791079
if idx == self.config.num_hidden_layers - 1:
10801080
layer.next_layer_layernorm = self.model.norm

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def load_weights(self, weights: Dict):
631631
else:
632632
self.load_hf_weights(weights)
633633

634-
def post_load_weights(self):
634+
def setup_aliases(self):
635635
for idx, layer in enumerate(
636636
self.model.block[:self.config.num_hidden_layers]):
637637
if idx == 0:

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def __init__(
484484
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
485485
eps=config.rms_norm_eps,
486486
dtype=config.torch_dtype)
487-
# When post_load_weights() chains layernorms across layers,
487+
# When setup_aliases() chains layernorms across layers,
488488
# this flag is set to True to skip the input layernorm in
489489
# forward() since it is handled by the previous layer.
490490
self.skip_input_layernorm = False
@@ -709,7 +709,7 @@ def __init__(
709709
quantize_type="nvfp4"
710710
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
711711
and not (differ_pp_stage_with_previous_layer) else None)
712-
# When post_load_weights() chains layernorms across layers,
712+
# When setup_aliases() chains layernorms across layers,
713713
# this flag is set to True to skip the input layernorm in
714714
# forward() since it is handled by the previous layer.
715715
self.skip_input_layernorm = False
@@ -983,7 +983,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
983983
self.norm = RMSNorm(hidden_size=config.hidden_size,
984984
eps=config.rms_norm_eps,
985985
dtype=config.torch_dtype)
986-
# When post_load_weights() chains the final norm into the
986+
# When setup_aliases() chains the final norm into the
987987
# last decoder layer, this flag is set to True to skip
988988
# applying it again in forward().
989989
self.skip_norm = False
@@ -1088,7 +1088,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
10881088
self.norm = RMSNorm(hidden_size=config.hidden_size,
10891089
eps=config.rms_norm_eps,
10901090
dtype=config.torch_dtype)
1091-
# When post_load_weights() chains the final norm into the
1091+
# When setup_aliases() chains the final norm into the
10921092
# last decoder layer, this flag is set to True to skip
10931093
# applying it again in forward().
10941094
self.skip_norm = False
@@ -1140,7 +1140,7 @@ def __init__(
11401140
):
11411141
super().__init__(LlamaModel(model_config), model_config)
11421142

1143-
def post_load_weights(self):
1143+
def setup_aliases(self):
11441144
for idx, layer in enumerate(
11451145
self.model.layers[:self.config.num_hidden_layers]):
11461146
if idx == self.config.num_hidden_layers - 1:
@@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
15641564
if had_mm_encoder:
15651565
self.mm_encoder = saved_mm_encoder
15661566

1567-
def post_load_weights(self):
1567+
def setup_aliases(self):
15681568
for idx, layer in enumerate(
15691569
self.model.layers[:self.config.num_hidden_layers]):
15701570
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def __init__(
417417
)
418418
self.preload_weight_modules = self.model.preload_weight_modules
419419

420-
def post_load_weights(self):
420+
def setup_aliases(self):
421421
for idx, layer in enumerate(
422422
self.model.layers[:self.config.num_hidden_layers]):
423423
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
980980
new_weights = weight_mapper.preprocess_weights(weights)
981981
super().load_weights(new_weights, weight_mapper)
982982

983-
def post_load_weights(self):
983+
def setup_aliases(self):
984984
for idx, layer in enumerate(
985985
self.model.layers[:self.config.num_hidden_layers]):
986986
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,9 @@ def init_meta_tensor(t: torch.Tensor):
462462
# post_load_* hooks itself, so the shared post-load block below
463463
# must skip them. RW handles them inside `mem_pool_scope` so the
464464
# committed pool reflects the post-post_load layout; RO runs
465-
# `module.post_load_weights()` before `materialize_module` to
466-
# wire aliases prior to zero-copy mapping.
465+
# ``setup_aliases()`` before ``materialize_module`` to wire aliases
466+
# prior to zero-copy mapping, then refreshes derived state after
467+
# real GMS tensors are bound.
467468
gms_post_load_handled = False
468469
if load_format == LoadFormat.AUTO:
469470
# Pass model= so format-specific loaders (e.g. MX) can
@@ -692,31 +693,33 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
692693
# presharded modules).
693694
#
694695
# Hook order:
695-
# 1. `post_load_apply`: format-specific apply
696+
# 1. ``post_load_apply``: format-specific apply
696697
# work (e.g., MX preshard markers).
697-
# 2. Per-module `post_load_weights`: creates
698-
# aliases/derived parameter attributes BEFORE
699-
# `materialize_module` walks the final module
700-
# tree (including `draft_model` for spec dec).
701-
# 3. `materialize_module`: zero-copy bind GMS
698+
# 2. Per-module ``setup_aliases``: creates structural
699+
# aliases BEFORE ``materialize_module`` walks the
700+
# final module tree (including ``draft_model`` for
701+
# spec dec).
702+
# 3. SourceIdentity gate: STRICT pre-materialize
703+
# compatibility check (GMS has no disk fallback).
704+
# 4. ``materialize_module``: zero-copy bind GMS
702705
# pool storage onto the model parameters.
703-
# 4. `post_load_publish`: any receiver-side
706+
# 5. Per-module ``cache_derived_state``: recompute
707+
# Python-side state from real, materialized
708+
# tensors without re-running one-shot transforms.
709+
# 6. ``post_load_publish``: any receiver-side
704710
# publish (no-op via the receiver guard).
705711
checkpoint_loader.post_load_apply(
706712
model, weights_preloaded=True)
707713

708-
for module in model.modules():
709-
if hasattr(module,
710-
'post_load_weights') and not getattr(
711-
module, '_weights_removed', False):
712-
module.post_load_weights()
714+
self._setup_aliases(model)
713715

714716
# Pre-materialize compatibility gate. GMS has no
715717
# disk-fallback path, so a mismatch raises under STRICT
716718
# rather than falling back.
717719
self._check_gms_source_identity(gms_backend)
718720

719721
gms_backend.materialize_module(model)
722+
self._walk_cache_state(model)
720723

721724
checkpoint_loader.post_load_publish(
722725
model,
@@ -806,22 +809,24 @@ def _check_gms_source_identity(self, gms_backend) -> None:
806809

807810
@staticmethod
808811
def _setup_aliases(model: DecoderModelForCausalLM) -> None:
809-
"""Run top-level structural alias setup if the model defines it.
812+
"""Run structural alias setup on eligible modules.
810813
811-
Alias wiring is a model-level concern. It is intentionally not a
812-
recursive module walk, because migrated aliases are expected to be set
813-
by the root model that owns the layer graph.
814+
The walk is duck-typed so modules can opt in without inheriting a
815+
shared base class. Modules whose weights were removed are skipped,
816+
matching the legacy full post-load walk.
814817
815818
Args:
816-
model: Root decoder model whose top-level alias hook should run.
819+
model: Root decoder model whose module tree should be visited.
817820
818821
Returns:
819822
None.
820823
"""
821-
setup_aliases: Optional[Callable[[], None]] = getattr(
822-
model, 'setup_aliases', None)
823-
if setup_aliases is not None:
824-
setup_aliases()
824+
for module in model.modules():
825+
setup_aliases: Optional[Callable[[], None]] = getattr(
826+
module, 'setup_aliases', None)
827+
if setup_aliases is not None and not getattr(
828+
module, '_weights_removed', False):
829+
setup_aliases()
825830

826831
@staticmethod
827832
def _walk_transform(model: DecoderModelForCausalLM) -> None:
@@ -929,6 +934,7 @@ def reload(self,
929934
"Cannot reload weights: weight_mapper was not initialized. "
930935
"This can happen when the initial load used GMS, MX P2P, or "
931936
"VISION_ONLY, which bypass the standard weight mapping path.")
937+
self._reset_weights_transformed(model)
932938
self._call_load_weights(model.load_weights,
933939
weights,
934940
self.weight_mapper,

tests/unittest/_torch/pyexecutor/test_model_loader_gms.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ def to(self, *args, **kwargs):
3434
def load_weights(self, weights, mapper):
3535
self._events.append("load_weights")
3636

37+
def setup_aliases(self):
38+
self._events.append("setup_aliases")
39+
40+
def cache_derived_state(self):
41+
self._events.append("cache_derived_state")
42+
3743
def post_load_weights(self):
3844
self._events.append("post_load_weights")
3945

@@ -66,11 +72,21 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
6672
loader._call_load_weights = MagicMock(
6773
side_effect=lambda fn, weights, mapper, **kwargs: fn(weights, mapper)
6874
)
69-
loader._load_and_validate_config = MagicMock(return_value=SimpleNamespace(name="config"))
75+
loader._load_and_validate_config = MagicMock(
76+
return_value=SimpleNamespace(name="config", mapping=SimpleNamespace())
77+
)
7078

7179
monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext())
7280
monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context)
7381
monkeypatch.setattr(model_loader_mod, "MetaInitMode", lambda: nullcontext())
82+
# GMS builds a receiver-side SourceIdentity from the resolved ModelConfig.
83+
# These tests stub the config, so short-circuit fingerprint construction to
84+
# a sentinel; identity-comparison behavior is covered in test_source_identity.py.
85+
monkeypatch.setattr(
86+
model_loader_mod.SourceIdentity,
87+
"from_model_config",
88+
classmethod(lambda cls, *_args, **_kwargs: SimpleNamespace(name="local-identity")),
89+
)
7490
monkeypatch.setattr(
7591
model_loader_mod.AutoModelForCausalLM,
7692
"from_config",
@@ -93,6 +109,7 @@ def _build_gms_backend(*, is_rw, events):
93109
if is_rw:
94110
backend.mem_pool_scope.side_effect = lambda _device: _pool_scope(events)
95111
else:
112+
backend.get_source_identity.return_value = None
96113

97114
def _materialize(_model):
98115
events.append("materialize")
@@ -129,7 +146,7 @@ def _spec_config_needing_draft_weights():
129146
),
130147
pytest.param(
131148
False,
132-
["post_load_weights", "materialize"],
149+
["setup_aliases", "materialize", "cache_derived_state"],
133150
id="ro",
134151
),
135152
],
@@ -143,8 +160,9 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
143160
(``_apply`` for meta materialization, ``to('cuda')``, weight
144161
load, ``post_load_weights``) inside the pool, then commits via
145162
``finalize_write`` once the scope exits.
146-
ro: the reader runs ``post_load_weights`` to wire module aliases
147-
first, then GMS materializes weights via zero-copy mapping.
163+
ro: the reader runs ``setup_aliases`` to wire module aliases, checks
164+
identity compatibility, materializes weights via zero-copy mapping,
165+
then refreshes derived state from real tensors.
148166
"""
149167
events = []
150168
loader = _make_loader(monkeypatch, events=events)
@@ -175,13 +193,55 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
175193
backend.move_untracked_params.assert_called_once_with(model)
176194
backend.finalize_write.assert_called_once_with(model)
177195
else:
178-
# RO: post_load_weights() must run before the GMS materialize
179-
# step so module aliases are wired up before zero-copy mapping.
196+
# RO: setup_aliases() must run before the GMS materialize step so
197+
# module aliases are wired up before zero-copy mapping.
180198
checkpoint_loader.load_weights.assert_not_called()
181199
loader._call_load_weights.assert_not_called()
182200
backend.materialize_module.assert_called_once_with(model)
183201

184202

203+
def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
204+
events = []
205+
loader = _make_loader(monkeypatch, events=events)
206+
backend = _build_gms_backend(is_rw=False, events=events)
207+
_install_gms_backend(monkeypatch, backend)
208+
209+
checkpoint_loader = MagicMock(name="checkpoint_loader")
210+
checkpoint_loader.checkpoint_format = "HF"
211+
212+
def record(event):
213+
def _append(*_args, **_kwargs):
214+
events.append(event)
215+
216+
return _append
217+
218+
checkpoint_loader.post_load_apply.side_effect = record("post_load_apply")
219+
checkpoint_loader.post_load_publish.side_effect = record("post_load_publish")
220+
221+
# The STRICT pre-materialize identity gate runs between alias setup and
222+
# materialization; record it to pin the ordering without exercising the
223+
# comparison logic, which is covered in test_source_identity.py.
224+
monkeypatch.setattr(
225+
model_loader_mod,
226+
"check_weight_sharing_compatibility",
227+
lambda *_args, **_kwargs: events.append("check_source_identity"),
228+
)
229+
230+
loader.load("/ckpt", checkpoint_loader)
231+
232+
assert events == [
233+
"post_load_apply",
234+
"setup_aliases",
235+
"check_source_identity",
236+
"materialize",
237+
"cache_derived_state",
238+
"post_load_publish",
239+
]
240+
assert "post_load_weights" not in events
241+
checkpoint_loader.load_weights.assert_not_called()
242+
backend.materialize_module.assert_called_once()
243+
244+
185245
def test_gms_rw_post_load_runs_inside_pool_before_finalize(monkeypatch):
186246
"""Every step that may allocate or rebind tensors must run inside the GMS pool.
187247

0 commit comments

Comments
 (0)