Skip to content

Commit 85973b0

Browse files
committed
[TRTLLM-13246][feat] Wave 1: migrate aliases to setup_aliases and stage GMS RO load
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent 09449d4 commit 85973b0

11 files changed

Lines changed: 172 additions & 58 deletions

File tree

tensorrt_llm/_torch/memory/gpu_memory_backend.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,16 @@
3636
CUDA memory pool. After loading, weights are committed for read-only
3737
access by other workers and the client transitions to RO mode in place.
3838
- **RO (Read-Only)**: Subsequent workers zero-copy import already-committed
39-
weights from the GMS pool. `post_load_weights()` must run BEFORE
40-
materialization so that module aliases are set up correctly.
39+
weights from the GMS pool. `setup_aliases()` must run BEFORE
40+
materialization so that module aliases are set up correctly, while derived
41+
state is refreshed after real tensors are bound. RO is validated for models
42+
whose `post_load_weights()` is pure alias wiring; models that additionally
43+
rely on plain Python attributes set inside `post_load_weights()` (rather
44+
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
45+
those side effects to `cache_derived_state()` or another path that runs on
46+
RO readers. The GMS RO reader runs `setup_aliases()` before
47+
`materialize_module()` and `cache_derived_state()` afterward; it does not
48+
run `transform_weights()`.
4149
"""
4250

4351
from contextlib import contextmanager
@@ -477,7 +485,7 @@ def materialize_module(self, model: nn.Module) -> None:
477485
by GPU pointers from the shared memory region — no data copies,
478486
no disk I/O, just CUDA VMM remapping. The model's submodule
479487
layout must already match the writer's at commit time, including
480-
any aliases / derived buffers introduced by `post_load_weights`.
488+
any aliases introduced by `setup_aliases`.
481489
482490
Args:
483491
model: The `nn.Module` to materialize. Walks the full
@@ -489,11 +497,10 @@ def materialize_module(self, model: nn.Module) -> None:
489497
RuntimeError: If `connect()` has not been called yet.
490498
491499
Note:
492-
`post_load_weights()` must be called on the model BEFORE
493-
this method. The order ensures that any aliases / derived
494-
parameters created by post-load hooks are present on the
495-
module tree at materialization time, so they are bound to
496-
the same GMS storage as their primary tensor.
500+
`setup_aliases()` must be called on the model BEFORE this method.
501+
The order ensures that any structural aliases created by post-load
502+
hooks are present on the module tree at materialization time, so
503+
they are bound to the same GMS storage as their primary tensor.
497504
"""
498505
if self._client is None:
499506
raise RuntimeError("GMS client not connected. Call connect() first.")

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1921,7 +1921,7 @@ def load_weights(self, weights: ConsumableWeightsDict):
19211921
weight_loader = DeepseekV3WeightLoader(self)
19221922
weight_loader.load_weights(weights)
19231923

1924-
def post_load_weights(self):
1924+
def setup_aliases(self) -> None:
19251925
for idx, layer in enumerate(
19261926
self.model.layers[:self.config.num_hidden_layers]):
19271927
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_exaone_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ def load_weights(
725725
allow_partial_loading=allow_partial_loading,
726726
)
727727

728-
def post_load_weights(self):
728+
def setup_aliases(self) -> None:
729729
# For the cross-layer residual+LN fusion.
730730
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
731731
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_glm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo
10741074
weight_loader = Glm4WeightLoader(self)
10751075
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
10761076

1077-
def post_load_weights(self):
1077+
def setup_aliases(self) -> None:
10781078
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
10791079
if idx == self.config.num_hidden_layers - 1:
10801080
layer.next_layer_layernorm = self.model.norm

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def load_weights(self, weights: Dict):
631631
else:
632632
self.load_hf_weights(weights)
633633

634-
def post_load_weights(self):
634+
def setup_aliases(self) -> None:
635635
for idx, layer in enumerate(
636636
self.model.block[:self.config.num_hidden_layers]):
637637
if idx == 0:

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def __init__(
484484
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
485485
eps=config.rms_norm_eps,
486486
dtype=config.torch_dtype)
487-
# When post_load_weights() chains layernorms across layers,
487+
# When setup_aliases() chains layernorms across layers,
488488
# this flag is set to True to skip the input layernorm in
489489
# forward() since it is handled by the previous layer.
490490
self.skip_input_layernorm = False
@@ -709,7 +709,7 @@ def __init__(
709709
quantize_type="nvfp4"
710710
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
711711
and not (differ_pp_stage_with_previous_layer) else None)
712-
# When post_load_weights() chains layernorms across layers,
712+
# When setup_aliases() chains layernorms across layers,
713713
# this flag is set to True to skip the input layernorm in
714714
# forward() since it is handled by the previous layer.
715715
self.skip_input_layernorm = False
@@ -983,7 +983,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
983983
self.norm = RMSNorm(hidden_size=config.hidden_size,
984984
eps=config.rms_norm_eps,
985985
dtype=config.torch_dtype)
986-
# When post_load_weights() chains the final norm into the
986+
# When setup_aliases() chains the final norm into the
987987
# last decoder layer, this flag is set to True to skip
988988
# applying it again in forward().
989989
self.skip_norm = False
@@ -1088,7 +1088,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
10881088
self.norm = RMSNorm(hidden_size=config.hidden_size,
10891089
eps=config.rms_norm_eps,
10901090
dtype=config.torch_dtype)
1091-
# When post_load_weights() chains the final norm into the
1091+
# When setup_aliases() chains the final norm into the
10921092
# last decoder layer, this flag is set to True to skip
10931093
# applying it again in forward().
10941094
self.skip_norm = False
@@ -1140,7 +1140,7 @@ def __init__(
11401140
):
11411141
super().__init__(LlamaModel(model_config), model_config)
11421142

1143-
def post_load_weights(self):
1143+
def setup_aliases(self) -> None:
11441144
for idx, layer in enumerate(
11451145
self.model.layers[:self.config.num_hidden_layers]):
11461146
if idx == self.config.num_hidden_layers - 1:
@@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
15641564
if had_mm_encoder:
15651565
self.mm_encoder = saved_mm_encoder
15661566

1567-
def post_load_weights(self):
1567+
def setup_aliases(self) -> None:
15681568
for idx, layer in enumerate(
15691569
self.model.layers[:self.config.num_hidden_layers]):
15701570
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def __init__(
417417
)
418418
self.preload_weight_modules = self.model.preload_weight_modules
419419

420-
def post_load_weights(self):
420+
def setup_aliases(self) -> None:
421421
for idx, layer in enumerate(
422422
self.model.layers[:self.config.num_hidden_layers]):
423423
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
980980
new_weights = weight_mapper.preprocess_weights(weights)
981981
super().load_weights(new_weights, weight_mapper)
982982

983-
def post_load_weights(self):
983+
def setup_aliases(self) -> None:
984984
for idx, layer in enumerate(
985985
self.model.layers[:self.config.num_hidden_layers]):
986986
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,9 @@ def init_meta_tensor(t: torch.Tensor):
462462
# post_load_* hooks itself, so the shared post-load block below
463463
# must skip them. RW handles them inside `mem_pool_scope` so the
464464
# committed pool reflects the post-post_load layout; RO runs
465-
# `module.post_load_weights()` before `materialize_module` to
466-
# wire aliases prior to zero-copy mapping.
465+
# `setup_aliases()` before `materialize_module` to wire aliases
466+
# prior to zero-copy mapping, then refreshes derived state after
467+
# real GMS tensors are bound.
467468
gms_post_load_handled = False
468469
if load_format == LoadFormat.AUTO:
469470
# Pass model= so format-specific loaders (e.g. MX) can
@@ -694,29 +695,31 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
694695
# Hook order:
695696
# 1. `post_load_apply`: format-specific apply
696697
# work (e.g., MX preshard markers).
697-
# 2. Per-module `post_load_weights`: creates
698-
# aliases/derived parameter attributes BEFORE
699-
# `materialize_module` walks the final module
700-
# tree (including `draft_model` for spec dec).
701-
# 3. `materialize_module`: zero-copy bind GMS
698+
# 2. Per-module `setup_aliases`: creates structural
699+
# aliases BEFORE `materialize_module` walks the
700+
# final module tree (including `draft_model` for
701+
# spec dec).
702+
# 3. SourceIdentity gate: STRICT pre-materialize
703+
# compatibility check (GMS has no disk fallback).
704+
# 4. `materialize_module`: zero-copy bind GMS
702705
# pool storage onto the model parameters.
703-
# 4. `post_load_publish`: any receiver-side
706+
# 5. Per-module `cache_derived_state`: recompute
707+
# Python-side state from real, materialized
708+
# tensors without re-running one-shot transforms.
709+
# 6. `post_load_publish`: any receiver-side
704710
# publish (no-op via the receiver guard).
705711
checkpoint_loader.post_load_apply(
706712
model, weights_preloaded=True)
707713

708-
for module in model.modules():
709-
if hasattr(module,
710-
'post_load_weights') and not getattr(
711-
module, '_weights_removed', False):
712-
module.post_load_weights()
714+
self._setup_aliases(model)
713715

714716
# Pre-materialize compatibility gate. GMS has no
715717
# disk-fallback path, so a mismatch raises under STRICT
716718
# rather than falling back.
717719
self._check_gms_source_identity(gms_backend)
718720

719721
gms_backend.materialize_module(model)
722+
self._walk_cache_state(model)
720723

721724
checkpoint_loader.post_load_publish(
722725
model,
@@ -806,22 +809,24 @@ def _check_gms_source_identity(self, gms_backend) -> None:
806809

807810
@staticmethod
808811
def _setup_aliases(model: DecoderModelForCausalLM) -> None:
809-
"""Run top-level structural alias setup if the model defines it.
812+
"""Run structural alias setup on eligible modules.
810813
811-
Alias wiring is a model-level concern. It is intentionally not a
812-
recursive module walk, because migrated aliases are expected to be set
813-
by the root model that owns the layer graph.
814+
The walk is duck-typed so modules can opt in without inheriting a
815+
shared base class. Modules whose weights were removed are skipped,
816+
matching the legacy full post-load walk.
814817
815818
Args:
816-
model: Root decoder model whose top-level alias hook should run.
819+
model: Root decoder model whose module tree should be visited.
817820
818821
Returns:
819822
None.
820823
"""
821-
setup_aliases: Optional[Callable[[], None]] = getattr(
822-
model, 'setup_aliases', None)
823-
if setup_aliases is not None:
824-
setup_aliases()
824+
for module in model.modules():
825+
setup_aliases: Optional[Callable[[], None]] = getattr(
826+
module, 'setup_aliases', None)
827+
if setup_aliases is not None and not getattr(
828+
module, '_weights_removed', False):
829+
setup_aliases()
825830

826831
@staticmethod
827832
def _walk_transform(model: DecoderModelForCausalLM) -> None:
@@ -912,8 +917,11 @@ def reload(self,
912917
"""Reload model weights without running post-load hooks.
913918
914919
Reload is used by incremental update paths that may provide only a
915-
partial set of replacement weights. The owner of the update lifecycle is
916-
responsible for running post-load processing once all bytes are present.
920+
partial set of replacement weights. Full reloads reset transform guards
921+
before rebinding fresh weights. Partial reloads keep existing transform
922+
guards intact because untouched modules may already contain transformed
923+
live weights. The owner of the update lifecycle is responsible for
924+
running post-load processing once all bytes are present.
917925
918926
Args:
919927
model: Model instance receiving the replacement weights.
@@ -929,6 +937,8 @@ def reload(self,
929937
"Cannot reload weights: weight_mapper was not initialized. "
930938
"This can happen when the initial load used GMS, MX P2P, or "
931939
"VISION_ONLY, which bypass the standard weight mapping path.")
940+
if not allow_partial_loading:
941+
self._reset_weights_transformed(model)
932942
self._call_load_weights(model.load_weights,
933943
weights,
934944
self.weight_mapper,

tests/unittest/_torch/pyexecutor/test_model_loader_gms.py

Lines changed: 71 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ def to(self, *args, **kwargs):
3434
def load_weights(self, weights, mapper):
3535
self._events.append("load_weights")
3636

37-
def post_load_weights(self):
37+
def setup_aliases(self) -> None:
38+
self._events.append("setup_aliases")
39+
40+
def cache_derived_state(self) -> None:
41+
self._events.append("cache_derived_state")
42+
43+
def post_load_weights(self) -> None:
3844
self._events.append("post_load_weights")
3945

4046

@@ -66,11 +72,21 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
6672
loader._call_load_weights = MagicMock(
6773
side_effect=lambda fn, weights, mapper, **kwargs: fn(weights, mapper)
6874
)
69-
loader._load_and_validate_config = MagicMock(return_value=SimpleNamespace(name="config"))
75+
loader._load_and_validate_config = MagicMock(
76+
return_value=SimpleNamespace(name="config", mapping=SimpleNamespace())
77+
)
7078

7179
monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext())
7280
monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context)
7381
monkeypatch.setattr(model_loader_mod, "MetaInitMode", lambda: nullcontext())
82+
# GMS builds a receiver-side SourceIdentity from the resolved ModelConfig.
83+
# These tests stub the config, so short-circuit fingerprint construction to
84+
# a sentinel; identity-comparison behavior is covered in test_source_identity.py.
85+
monkeypatch.setattr(
86+
model_loader_mod.SourceIdentity,
87+
"from_model_config",
88+
classmethod(lambda cls, *_args, **_kwargs: SimpleNamespace(name="local-identity")),
89+
)
7490
monkeypatch.setattr(
7591
model_loader_mod.AutoModelForCausalLM,
7692
"from_config",
@@ -93,6 +109,7 @@ def _build_gms_backend(*, is_rw, events):
93109
if is_rw:
94110
backend.mem_pool_scope.side_effect = lambda _device: _pool_scope(events)
95111
else:
112+
backend.get_source_identity.return_value = None
96113

97114
def _materialize(_model):
98115
events.append("materialize")
@@ -129,7 +146,7 @@ def _spec_config_needing_draft_weights():
129146
),
130147
pytest.param(
131148
False,
132-
["post_load_weights", "materialize"],
149+
["setup_aliases", "materialize", "cache_derived_state"],
133150
id="ro",
134151
),
135152
],
@@ -143,8 +160,9 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
143160
(``_apply`` for meta materialization, ``to('cuda')``, weight
144161
load, ``post_load_weights``) inside the pool, then commits via
145162
``finalize_write`` once the scope exits.
146-
ro: the reader runs ``post_load_weights`` to wire module aliases
147-
first, then GMS materializes weights via zero-copy mapping.
163+
ro: the reader runs ``setup_aliases`` to wire module aliases, checks
164+
identity compatibility, materializes weights via zero-copy mapping,
165+
then refreshes derived state from real tensors.
148166
"""
149167
events = []
150168
loader = _make_loader(monkeypatch, events=events)
@@ -169,19 +187,64 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
169187
# path (see model_loader.py); HF ignores it, MX uses it for direct
170188
# P2P writes when MX+GMS composition eventually lands.
171189
checkpoint_loader.load_weights.assert_called_once_with(
172-
"/ckpt", mapping=loader.mapping, model=model
190+
"/ckpt",
191+
mapping=loader.mapping,
192+
model=model,
193+
source_identity=loader._source_identity,
173194
)
174195
loader._call_load_weights.assert_called_once()
175196
backend.move_untracked_params.assert_called_once_with(model)
176197
backend.finalize_write.assert_called_once_with(model)
177198
else:
178-
# RO: post_load_weights() must run before the GMS materialize
179-
# step so module aliases are wired up before zero-copy mapping.
199+
# RO: setup_aliases() must run before the GMS materialize step so
200+
# module aliases are wired up before zero-copy mapping.
180201
checkpoint_loader.load_weights.assert_not_called()
181202
loader._call_load_weights.assert_not_called()
182203
backend.materialize_module.assert_called_once_with(model)
183204

184205

206+
def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
207+
events = []
208+
loader = _make_loader(monkeypatch, events=events)
209+
backend = _build_gms_backend(is_rw=False, events=events)
210+
_install_gms_backend(monkeypatch, backend)
211+
212+
checkpoint_loader = MagicMock(name="checkpoint_loader")
213+
checkpoint_loader.checkpoint_format = "HF"
214+
215+
def record(event):
216+
def _append(*_args, **_kwargs):
217+
events.append(event)
218+
219+
return _append
220+
221+
checkpoint_loader.post_load_apply.side_effect = record("post_load_apply")
222+
checkpoint_loader.post_load_publish.side_effect = record("post_load_publish")
223+
224+
# The STRICT pre-materialize identity gate runs between alias setup and
225+
# materialization; record it to pin the ordering without exercising the
226+
# comparison logic, which is covered in test_source_identity.py.
227+
monkeypatch.setattr(
228+
model_loader_mod,
229+
"check_weight_sharing_compatibility",
230+
lambda *_args, **_kwargs: events.append("check_source_identity"),
231+
)
232+
233+
loader.load("/ckpt", checkpoint_loader)
234+
235+
assert events == [
236+
"post_load_apply",
237+
"setup_aliases",
238+
"check_source_identity",
239+
"materialize",
240+
"cache_derived_state",
241+
"post_load_publish",
242+
]
243+
assert "post_load_weights" not in events
244+
checkpoint_loader.load_weights.assert_not_called()
245+
backend.materialize_module.assert_called_once()
246+
247+
185248
def test_gms_rw_post_load_runs_inside_pool_before_finalize(monkeypatch):
186249
"""Every step that may allocate or rebind tensors must run inside the GMS pool.
187250

0 commit comments

Comments
 (0)