Skip to content

Commit 7eec9fe

Browse files
committed
[TRTLLM-13246][feat] Wave 1: migrate aliases to setup_aliases and stage GMS RO load
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent 55fa04d commit 7eec9fe

11 files changed

Lines changed: 150 additions & 57 deletions

File tree

tensorrt_llm/_torch/memory/gpu_memory_backend.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,16 @@
3636
CUDA memory pool. After loading, weights are committed for read-only
3737
access by other workers and the client transitions to RO mode in place.
3838
- **RO (Read-Only)**: Subsequent workers zero-copy import already-committed
39-
weights from the GMS pool. `post_load_weights()` must run BEFORE
40-
materialization so that module aliases are set up correctly.
39+
weights from the GMS pool. `setup_aliases()` must run BEFORE
40+
materialization so that module aliases are set up correctly, while derived
41+
state is refreshed after real tensors are bound. RO is validated for models
42+
whose `post_load_weights()` is pure alias wiring; models that additionally
43+
rely on plain Python attributes set inside `post_load_weights()` (rather
44+
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
45+
those side effects to `cache_derived_state()` or another path that runs on
46+
RO readers. The GMS RO reader runs `setup_aliases()` before
47+
`materialize_module()` and `cache_derived_state()` afterward; it does not
48+
run `transform_weights()`.
4149
"""
4250

4351
from contextlib import contextmanager
@@ -477,7 +485,7 @@ def materialize_module(self, model: nn.Module) -> None:
477485
by GPU pointers from the shared memory region — no data copies,
478486
no disk I/O, just CUDA VMM remapping. The model's submodule
479487
layout must already match the writer's at commit time, including
480-
any aliases / derived buffers introduced by `post_load_weights`.
488+
any aliases introduced by `setup_aliases`.
481489
482490
Args:
483491
model: The `nn.Module` to materialize. Walks the full
@@ -489,11 +497,10 @@ def materialize_module(self, model: nn.Module) -> None:
489497
RuntimeError: If `connect()` has not been called yet.
490498
491499
Note:
492-
`post_load_weights()` must be called on the model BEFORE
493-
this method. The order ensures that any aliases / derived
494-
parameters created by post-load hooks are present on the
495-
module tree at materialization time, so they are bound to
496-
the same GMS storage as their primary tensor.
500+
`setup_aliases()` must be called on the model BEFORE this method.
501+
The order ensures that any structural aliases created by that hook
502+
are present on the module tree at materialization time, so
503+
they are bound to the same GMS storage as their primary tensor.
497504
"""
498505
if self._client is None:
499506
raise RuntimeError("GMS client not connected. Call connect() first.")

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1921,7 +1921,7 @@ def load_weights(self, weights: ConsumableWeightsDict):
19211921
weight_loader = DeepseekV3WeightLoader(self)
19221922
weight_loader.load_weights(weights)
19231923

1924-
def post_load_weights(self):
1924+
def setup_aliases(self) -> None:
19251925
for idx, layer in enumerate(
19261926
self.model.layers[:self.config.num_hidden_layers]):
19271927
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_exaone_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,7 @@ def load_weights(
725725
allow_partial_loading=allow_partial_loading,
726726
)
727727

728-
def post_load_weights(self):
728+
def setup_aliases(self) -> None:
729729
# For the cross-layer residual+LN fusion.
730730
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
731731
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_glm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo
10741074
weight_loader = Glm4WeightLoader(self)
10751075
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)
10761076

1077-
def post_load_weights(self):
1077+
def setup_aliases(self) -> None:
10781078
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
10791079
if idx == self.config.num_hidden_layers - 1:
10801080
layer.next_layer_layernorm = self.model.norm

tensorrt_llm/_torch/models/modeling_gpt_oss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def load_weights(self, weights: Dict):
631631
else:
632632
self.load_hf_weights(weights)
633633

634-
def post_load_weights(self):
634+
def setup_aliases(self) -> None:
635635
for idx, layer in enumerate(
636636
self.model.block[:self.config.num_hidden_layers]):
637637
if idx == 0:

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def __init__(
484484
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
485485
eps=config.rms_norm_eps,
486486
dtype=config.torch_dtype)
487-
# When post_load_weights() chains layernorms across layers,
487+
# When setup_aliases() chains layernorms across layers,
488488
# this flag is set to True to skip the input layernorm in
489489
# forward() since it is handled by the previous layer.
490490
self.skip_input_layernorm = False
@@ -709,7 +709,7 @@ def __init__(
709709
quantize_type="nvfp4"
710710
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
711711
and not (differ_pp_stage_with_previous_layer) else None)
712-
# When post_load_weights() chains layernorms across layers,
712+
# When setup_aliases() chains layernorms across layers,
713713
# this flag is set to True to skip the input layernorm in
714714
# forward() since it is handled by the previous layer.
715715
self.skip_input_layernorm = False
@@ -983,7 +983,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
983983
self.norm = RMSNorm(hidden_size=config.hidden_size,
984984
eps=config.rms_norm_eps,
985985
dtype=config.torch_dtype)
986-
# When post_load_weights() chains the final norm into the
986+
# When setup_aliases() chains the final norm into the
987987
# last decoder layer, this flag is set to True to skip
988988
# applying it again in forward().
989989
self.skip_norm = False
@@ -1088,7 +1088,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
10881088
self.norm = RMSNorm(hidden_size=config.hidden_size,
10891089
eps=config.rms_norm_eps,
10901090
dtype=config.torch_dtype)
1091-
# When post_load_weights() chains the final norm into the
1091+
# When setup_aliases() chains the final norm into the
10921092
# last decoder layer, this flag is set to True to skip
10931093
# applying it again in forward().
10941094
self.skip_norm = False
@@ -1140,7 +1140,7 @@ def __init__(
11401140
):
11411141
super().__init__(LlamaModel(model_config), model_config)
11421142

1143-
def post_load_weights(self):
1143+
def setup_aliases(self) -> None:
11441144
for idx, layer in enumerate(
11451145
self.model.layers[:self.config.num_hidden_layers]):
11461146
if idx == self.config.num_hidden_layers - 1:
@@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
15641564
if had_mm_encoder:
15651565
self.mm_encoder = saved_mm_encoder
15661566

1567-
def post_load_weights(self):
1567+
def setup_aliases(self) -> None:
15681568
for idx, layer in enumerate(
15691569
self.model.layers[:self.config.num_hidden_layers]):
15701570
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,7 +417,7 @@ def __init__(
417417
)
418418
self.preload_weight_modules = self.model.preload_weight_modules
419419

420-
def post_load_weights(self):
420+
def setup_aliases(self) -> None:
421421
for idx, layer in enumerate(
422422
self.model.layers[:self.config.num_hidden_layers]):
423423
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/models/modeling_qwen3_next.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
980980
new_weights = weight_mapper.preprocess_weights(weights)
981981
super().load_weights(new_weights, weight_mapper)
982982

983-
def post_load_weights(self):
983+
def setup_aliases(self) -> None:
984984
for idx, layer in enumerate(
985985
self.model.layers[:self.config.num_hidden_layers]):
986986
if idx == self.config.num_hidden_layers - 1:

tensorrt_llm/_torch/pyexecutor/model_loader.py

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,9 @@ def init_meta_tensor(t: torch.Tensor):
485485
# post_load_* hooks itself, so the shared post-load block below
486486
# must skip them. RW handles them inside `mem_pool_scope` so the
487487
# committed pool reflects the post-post_load layout; RO runs
488-
# `module.post_load_weights()` before `materialize_module` to
489-
# wire aliases prior to zero-copy mapping.
488+
# `setup_aliases()` before `materialize_module` to wire aliases
489+
# prior to zero-copy mapping, then refreshes derived state after
490+
# real GMS tensors are bound.
490491
gms_post_load_handled = False
491492
if load_format == LoadFormat.AUTO:
492493
# Pass model= so format-specific loaders (e.g. MX) can
@@ -717,29 +718,31 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
717718
# Hook order:
718719
# 1. `post_load_apply`: format-specific apply
719720
# work (e.g., MX preshard markers).
720-
# 2. Per-module `post_load_weights`: creates
721-
# aliases/derived parameter attributes BEFORE
722-
# `materialize_module` walks the final module
723-
# tree (including `draft_model` for spec dec).
724-
# 3. `materialize_module`: zero-copy bind GMS
721+
# 2. Per-module `setup_aliases`: creates structural
722+
# aliases BEFORE `materialize_module` walks the
723+
# final module tree (including `draft_model` for
724+
# spec dec).
725+
# 3. SourceIdentity gate: STRICT pre-materialize
726+
# compatibility check (GMS has no disk fallback).
727+
# 4. `materialize_module`: zero-copy bind GMS
725728
# pool storage onto the model parameters.
726-
# 4. `post_load_publish`: any receiver-side
729+
# 5. Per-module `cache_derived_state`: recompute
730+
# Python-side state from real, materialized
731+
# tensors without re-running one-shot transforms.
732+
# 6. `post_load_publish`: any receiver-side
727733
# publish (no-op via the receiver guard).
728734
checkpoint_loader.post_load_apply(
729735
model, weights_preloaded=True)
730736

731-
for module in model.modules():
732-
if hasattr(module,
733-
'post_load_weights') and not getattr(
734-
module, '_weights_removed', False):
735-
module.post_load_weights()
737+
self._setup_aliases(model)
736738

737739
# Pre-materialize compatibility gate. GMS has no
738740
# disk-fallback path, so a mismatch raises under STRICT
739741
# rather than falling back.
740742
self._check_gms_source_identity(gms_backend)
741743

742744
gms_backend.materialize_module(model)
745+
self._walk_cache_state(model)
743746

744747
checkpoint_loader.post_load_publish(
745748
model,
@@ -829,22 +832,24 @@ def _check_gms_source_identity(self, gms_backend) -> None:
829832

830833
@staticmethod
831834
def _setup_aliases(model: DecoderModelForCausalLM) -> None:
832-
"""Run top-level structural alias setup if the model defines it.
835+
"""Run structural alias setup on eligible modules.
833836
834-
Alias wiring is a model-level concern. It is intentionally not a
835-
recursive module walk, because migrated aliases are expected to be set
836-
by the root model that owns the layer graph.
837+
The walk is duck-typed so modules can opt in without inheriting a
838+
shared base class. Modules flagged `_weights_removed` are skipped,
839+
matching the legacy full post-load walk.
837840
838841
Args:
839-
model: Root decoder model whose top-level alias hook should run.
842+
model: Root decoder model whose module tree should be visited.
840843
841844
Returns:
842845
None.
843846
"""
844-
setup_aliases: Optional[Callable[[], None]] = getattr(
845-
model, 'setup_aliases', None)
846-
if setup_aliases is not None:
847-
setup_aliases()
847+
for module in model.modules():
848+
setup_aliases: Optional[Callable[[], None]] = getattr(
849+
module, 'setup_aliases', None)
850+
if setup_aliases is not None and not getattr(
851+
module, '_weights_removed', False):
852+
setup_aliases()
848853

849854
@staticmethod
850855
def _walk_transform(model: DecoderModelForCausalLM) -> None:
@@ -935,8 +940,11 @@ def reload(self,
935940
"""Reload model weights without running post-load hooks.
936941
937942
Reload is used by incremental update paths that may provide only a
938-
partial set of replacement weights. The owner of the update lifecycle is
939-
responsible for running post-load processing once all bytes are present.
943+
partial set of replacement weights. Full reloads reset transform guards
944+
before rebinding fresh weights. Partial reloads keep existing transform
945+
guards intact because untouched modules may already contain transformed
946+
live weights. The owner of the update lifecycle is responsible for
947+
running post-load processing once all bytes are present.
940948
941949
Args:
942950
model: Model instance receiving the replacement weights.
@@ -952,6 +960,8 @@ def reload(self,
952960
"Cannot reload weights: weight_mapper was not initialized. "
953961
"This can happen when the initial load used GMS, MX P2P, or "
954962
"VISION_ONLY, which bypass the standard weight mapping path.")
963+
if not allow_partial_loading:
964+
self._reset_weights_transformed(model)
955965
self._call_load_weights(model.load_weights,
956966
weights,
957967
self.weight_mapper,

tests/unittest/_torch/pyexecutor/test_model_loader_gms.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ def to(self, *args, **kwargs):
4444
def load_weights(self, weights, mapper):
4545
self._events.append("load_weights")
4646

47-
def post_load_weights(self):
47+
def setup_aliases(self) -> None:
48+
self._events.append("setup_aliases")
49+
50+
def cache_derived_state(self) -> None:
51+
self._events.append("cache_derived_state")
52+
53+
def post_load_weights(self) -> None:
4854
self._events.append("post_load_weights")
4955

5056

@@ -76,7 +82,9 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
7682
loader._call_load_weights = MagicMock(
7783
side_effect=lambda fn, weights, mapper, **kwargs: fn(weights, mapper)
7884
)
79-
loader._load_and_validate_config = MagicMock(return_value=SimpleNamespace(name="config"))
85+
loader._load_and_validate_config = MagicMock(
86+
return_value=SimpleNamespace(name="config", mapping=SimpleNamespace())
87+
)
8088

8189
monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext())
8290
monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context)
@@ -147,7 +155,7 @@ def _spec_config_needing_draft_weights():
147155
),
148156
pytest.param(
149157
False,
150-
["post_load_weights", "materialize"],
158+
["setup_aliases", "materialize", "cache_derived_state"],
151159
id="ro",
152160
),
153161
],
@@ -161,8 +169,9 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
161169
(``_apply`` for meta materialization, ``to('cuda')``, weight
162170
load, ``post_load_weights``) inside the pool, then commits via
163171
``finalize_write`` once the scope exits.
164-
ro: the reader runs ``post_load_weights`` to wire module aliases
165-
first, then GMS materializes weights via zero-copy mapping.
172+
ro: the reader runs ``setup_aliases`` to wire module aliases, checks
173+
identity compatibility, materializes weights via zero-copy mapping,
174+
then refreshes derived state from real tensors.
166175
"""
167176
events = []
168177
loader = _make_loader(monkeypatch, events=events)
@@ -196,13 +205,55 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
196205
backend.move_untracked_params.assert_called_once_with(model)
197206
backend.finalize_write.assert_called_once_with(model)
198207
else:
199-
# RO: post_load_weights() must run before the GMS materialize
200-
# step so module aliases are wired up before zero-copy mapping.
208+
# RO: setup_aliases() must run before the GMS materialize step so
209+
# module aliases are wired up before zero-copy mapping.
201210
checkpoint_loader.load_weights.assert_not_called()
202211
loader._call_load_weights.assert_not_called()
203212
backend.materialize_module.assert_called_once_with(model)
204213

205214

215+
def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
216+
events = []
217+
loader = _make_loader(monkeypatch, events=events)
218+
backend = _build_gms_backend(is_rw=False, events=events)
219+
_install_gms_backend(monkeypatch, backend)
220+
221+
checkpoint_loader = MagicMock(name="checkpoint_loader")
222+
checkpoint_loader.checkpoint_format = "HF"
223+
224+
def record(event):
225+
def _append(*_args, **_kwargs):
226+
events.append(event)
227+
228+
return _append
229+
230+
checkpoint_loader.post_load_apply.side_effect = record("post_load_apply")
231+
checkpoint_loader.post_load_publish.side_effect = record("post_load_publish")
232+
233+
# The STRICT pre-materialize identity gate runs between alias setup and
234+
# materialization; record it to pin the ordering without exercising the
235+
# comparison logic, which is covered in test_source_identity.py.
236+
monkeypatch.setattr(
237+
model_loader_mod,
238+
"check_weight_sharing_compatibility",
239+
lambda *_args, **_kwargs: events.append("check_source_identity"),
240+
)
241+
242+
loader.load("/ckpt", checkpoint_loader)
243+
244+
assert events == [
245+
"post_load_apply",
246+
"setup_aliases",
247+
"check_source_identity",
248+
"materialize",
249+
"cache_derived_state",
250+
"post_load_publish",
251+
]
252+
assert "post_load_weights" not in events
253+
checkpoint_loader.load_weights.assert_not_called()
254+
backend.materialize_module.assert_called_once()
255+
256+
206257
def test_gms_rw_post_load_runs_inside_pool_before_finalize(monkeypatch):
207258
"""Every step that may allocate or rebind tensors must run inside the GMS pool.
208259

0 commit comments

Comments
 (0)