Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions tensorrt_llm/_torch/memory/gpu_memory_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@
CUDA memory pool. After loading, weights are committed for read-only
access by other workers and the client transitions to RO mode in place.
- **RO (Read-Only)**: Subsequent workers zero-copy import already-committed
weights from the GMS pool. `post_load_weights()` must run BEFORE
materialization so that module aliases are set up correctly.
weights from the GMS pool. `setup_aliases()` must run BEFORE
materialization so that module aliases are set up correctly, while derived
state is refreshed after real tensors are bound. RO is validated for models
whose `post_load_weights()` is pure alias wiring; models that additionally
rely on plain Python attributes set inside `post_load_weights()` (rather
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
those side effects to `cache_derived_state()` or another path that runs on
RO readers. The GMS RO reader runs `setup_aliases()` before
`materialize_module()` and `cache_derived_state()` afterward; it does not
run `transform_weights()`.
"""

from contextlib import contextmanager
Expand Down Expand Up @@ -477,7 +485,7 @@ def materialize_module(self, model: nn.Module) -> None:
by GPU pointers from the shared memory region — no data copies,
no disk I/O, just CUDA VMM remapping. The model's submodule
layout must already match the writer's at commit time, including
any aliases / derived buffers introduced by `post_load_weights`.
any aliases introduced by `setup_aliases`.

Args:
model: The `nn.Module` to materialize. Walks the full
Expand All @@ -489,11 +497,10 @@ def materialize_module(self, model: nn.Module) -> None:
RuntimeError: If `connect()` has not been called yet.

Note:
`post_load_weights()` must be called on the model BEFORE
this method. The order ensures that any aliases / derived
parameters created by post-load hooks are present on the
module tree at materialization time, so they are bound to
the same GMS storage as their primary tensor.
`setup_aliases()` must be called on the model BEFORE this method.
The order ensures that any structural aliases created by post-load
hooks are present on the module tree at materialization time, so
they are bound to the same GMS storage as their primary tensor.
"""
if self._client is None:
raise RuntimeError("GMS client not connected. Call connect() first.")
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_deepseekv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -1921,7 +1921,7 @@ def load_weights(self, weights: ConsumableWeightsDict):
weight_loader = DeepseekV3WeightLoader(self)
weight_loader.load_weights(weights)

def post_load_weights(self):
def setup_aliases(self) -> None:
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_exaone_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,7 @@ def load_weights(
allow_partial_loading=allow_partial_loading,
)

def post_load_weights(self):
def setup_aliases(self) -> None:
# For the cross-layer residual+LN fusion.
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo
weight_loader = Glm4WeightLoader(self)
weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading)

def post_load_weights(self):
def setup_aliases(self) -> None:
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
layer.next_layer_layernorm = self.model.norm
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ def load_weights(self, weights: Dict):
else:
self.load_hf_weights(weights)

def post_load_weights(self):
def setup_aliases(self) -> None:
for idx, layer in enumerate(
self.model.block[:self.config.num_hidden_layers]):
if idx == 0:
Expand Down
12 changes: 6 additions & 6 deletions tensorrt_llm/_torch/models/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ def __init__(
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
# When post_load_weights() chains layernorms across layers,
# When setup_aliases() chains layernorms across layers,
# this flag is set to True to skip the input layernorm in
# forward() since it is handled by the previous layer.
self.skip_input_layernorm = False
Expand Down Expand Up @@ -709,7 +709,7 @@ def __init__(
quantize_type="nvfp4"
if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4
and not (differ_pp_stage_with_previous_layer) else None)
# When post_load_weights() chains layernorms across layers,
# When setup_aliases() chains layernorms across layers,
# this flag is set to True to skip the input layernorm in
# forward() since it is handled by the previous layer.
self.skip_input_layernorm = False
Expand Down Expand Up @@ -983,7 +983,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
# When post_load_weights() chains the final norm into the
# When setup_aliases() chains the final norm into the
# last decoder layer, this flag is set to True to skip
# applying it again in forward().
self.skip_norm = False
Expand Down Expand Up @@ -1088,7 +1088,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]):
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
# When post_load_weights() chains the final norm into the
# When setup_aliases() chains the final norm into the
# last decoder layer, this flag is set to True to skip
# applying it again in forward().
self.skip_norm = False
Expand Down Expand Up @@ -1140,7 +1140,7 @@ def __init__(
):
super().__init__(LlamaModel(model_config), model_config)

def post_load_weights(self):
def setup_aliases(self) -> None:
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
Expand Down Expand Up @@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
if had_mm_encoder:
self.mm_encoder = saved_mm_encoder

def post_load_weights(self):
def setup_aliases(self) -> None:
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def __init__(
)
self.preload_weight_modules = self.model.preload_weight_modules

def post_load_weights(self):
def setup_aliases(self) -> None:
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper):
new_weights = weight_mapper.preprocess_weights(weights)
super().load_weights(new_weights, weight_mapper)

def post_load_weights(self):
def setup_aliases(self) -> None:
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:
Expand Down
58 changes: 34 additions & 24 deletions tensorrt_llm/_torch/pyexecutor/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,9 @@ def init_meta_tensor(t: torch.Tensor):
# post_load_* hooks itself, so the shared post-load block below
# must skip them. RW handles them inside `mem_pool_scope` so the
# committed pool reflects the post-post_load layout; RO runs
# `module.post_load_weights()` before `materialize_module` to
# wire aliases prior to zero-copy mapping.
# `setup_aliases()` before `materialize_module` to wire aliases
# prior to zero-copy mapping, then refreshes derived state after
# real GMS tensors are bound.
gms_post_load_handled = False
if load_format == LoadFormat.AUTO:
# Pass model= so format-specific loaders (e.g. MX) can
Expand Down Expand Up @@ -717,29 +718,31 @@ def init_meta_tensor_in_pool(t: torch.Tensor):
# Hook order:
# 1. `post_load_apply`: format-specific apply
# work (e.g., MX preshard markers).
# 2. Per-module `post_load_weights`: creates
# aliases/derived parameter attributes BEFORE
# `materialize_module` walks the final module
# tree (including `draft_model` for spec dec).
# 3. `materialize_module`: zero-copy bind GMS
# 2. Per-module `setup_aliases`: creates structural
# aliases BEFORE `materialize_module` walks the
# final module tree (including `draft_model` for
# spec dec).
# 3. SourceIdentity gate: STRICT pre-materialize
# compatibility check (GMS has no disk fallback).
# 4. `materialize_module`: zero-copy bind GMS
# pool storage onto the model parameters.
# 4. `post_load_publish`: any receiver-side
# 5. Per-module `cache_derived_state`: recompute
# Python-side state from real, materialized
# tensors without re-running one-shot transforms.
# 6. `post_load_publish`: any receiver-side
# publish (no-op via the receiver guard).
checkpoint_loader.post_load_apply(
model, weights_preloaded=True)

for module in model.modules():
if hasattr(module,
'post_load_weights') and not getattr(
module, '_weights_removed', False):
module.post_load_weights()
self._setup_aliases(model)

# Pre-materialize compatibility gate. GMS has no
# disk-fallback path, so a mismatch raises under STRICT
# rather than falling back.
self._check_gms_source_identity(gms_backend)

gms_backend.materialize_module(model)
self._walk_cache_state(model)

checkpoint_loader.post_load_publish(
model,
Expand Down Expand Up @@ -829,22 +832,24 @@ def _check_gms_source_identity(self, gms_backend) -> None:

@staticmethod
def _setup_aliases(model: DecoderModelForCausalLM) -> None:
"""Run top-level structural alias setup if the model defines it.
"""Run structural alias setup on eligible modules.

Alias wiring is a model-level concern. It is intentionally not a
recursive module walk, because migrated aliases are expected to be set
by the root model that owns the layer graph.
The walk is duck-typed so modules can opt in without inheriting a
shared base class. Modules whose weights were removed are skipped,
matching the legacy full post-load walk.

Args:
model: Root decoder model whose top-level alias hook should run.
model: Root decoder model whose module tree should be visited.

Returns:
None.
"""
setup_aliases: Optional[Callable[[], None]] = getattr(
model, 'setup_aliases', None)
if setup_aliases is not None:
setup_aliases()
for module in model.modules():
setup_aliases: Optional[Callable[[], None]] = getattr(
module, 'setup_aliases', None)
if setup_aliases is not None and not getattr(
module, '_weights_removed', False):
setup_aliases()

@staticmethod
def _walk_transform(model: DecoderModelForCausalLM) -> None:
Expand Down Expand Up @@ -935,8 +940,11 @@ def reload(self,
"""Reload model weights without running post-load hooks.

Reload is used by incremental update paths that may provide only a
partial set of replacement weights. The owner of the update lifecycle is
responsible for running post-load processing once all bytes are present.
partial set of replacement weights. Full reloads reset transform guards
before rebinding fresh weights. Partial reloads keep existing transform
guards intact because untouched modules may already contain transformed
live weights. The owner of the update lifecycle is responsible for
running post-load processing once all bytes are present.

Args:
model: Model instance receiving the replacement weights.
Expand All @@ -952,6 +960,8 @@ def reload(self,
"Cannot reload weights: weight_mapper was not initialized. "
"This can happen when the initial load used GMS, MX P2P, or "
"VISION_ONLY, which bypass the standard weight mapping path.")
if not allow_partial_loading:
self._reset_weights_transformed(model)
self._call_load_weights(model.load_weights,
weights,
self.weight_mapper,
Expand Down
65 changes: 58 additions & 7 deletions tests/unittest/_torch/pyexecutor/test_model_loader_gms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,13 @@ def to(self, *args, **kwargs):
def load_weights(self, weights, mapper):
self._events.append("load_weights")

def post_load_weights(self):
def setup_aliases(self) -> None:
self._events.append("setup_aliases")

def cache_derived_state(self) -> None:
self._events.append("cache_derived_state")

def post_load_weights(self) -> None:
self._events.append("post_load_weights")


Expand Down Expand Up @@ -76,7 +82,9 @@ def _make_loader(monkeypatch, *, events, spec_config=None):
loader._call_load_weights = MagicMock(
side_effect=lambda fn, weights, mapper, **kwargs: fn(weights, mapper)
)
loader._load_and_validate_config = MagicMock(return_value=SimpleNamespace(name="config"))
loader._load_and_validate_config = MagicMock(
return_value=SimpleNamespace(name="config", mapping=SimpleNamespace())
)

monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext())
monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context)
Expand Down Expand Up @@ -147,7 +155,7 @@ def _spec_config_needing_draft_weights():
),
pytest.param(
False,
["post_load_weights", "materialize"],
["setup_aliases", "materialize", "cache_derived_state"],
id="ro",
),
],
Expand All @@ -161,8 +169,9 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
(``_apply`` for meta materialization, ``to('cuda')``, weight
load, ``post_load_weights``) inside the pool, then commits via
``finalize_write`` once the scope exits.
ro: the reader runs ``post_load_weights`` to wire module aliases
first, then GMS materializes weights via zero-copy mapping.
ro: the reader runs ``setup_aliases`` to wire module aliases, checks
identity compatibility, materializes weights via zero-copy mapping,
then refreshes derived state from real tensors.
"""
events = []
loader = _make_loader(monkeypatch, events=events)
Expand Down Expand Up @@ -196,13 +205,55 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
backend.move_untracked_params.assert_called_once_with(model)
backend.finalize_write.assert_called_once_with(model)
else:
# RO: post_load_weights() must run before the GMS materialize
# step so module aliases are wired up before zero-copy mapping.
# RO: setup_aliases() must run before the GMS materialize step so
# module aliases are wired up before zero-copy mapping.
checkpoint_loader.load_weights.assert_not_called()
loader._call_load_weights.assert_not_called()
backend.materialize_module.assert_called_once_with(model)


def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch):
events = []
loader = _make_loader(monkeypatch, events=events)
backend = _build_gms_backend(is_rw=False, events=events)
_install_gms_backend(monkeypatch, backend)

checkpoint_loader = MagicMock(name="checkpoint_loader")
checkpoint_loader.checkpoint_format = "HF"

def record(event):
def _append(*_args, **_kwargs):
events.append(event)

return _append

checkpoint_loader.post_load_apply.side_effect = record("post_load_apply")
checkpoint_loader.post_load_publish.side_effect = record("post_load_publish")

# The STRICT pre-materialize identity gate runs between alias setup and
# materialization; record it to pin the ordering without exercising the
# comparison logic, which is covered in test_source_identity.py.
monkeypatch.setattr(
model_loader_mod,
"check_weight_sharing_compatibility",
lambda *_args, **_kwargs: events.append("check_source_identity"),
)

loader.load("/ckpt", checkpoint_loader)

assert events == [
"post_load_apply",
"setup_aliases",
"check_source_identity",
"materialize",
"cache_derived_state",
"post_load_publish",
]
assert "post_load_weights" not in events
checkpoint_loader.load_weights.assert_not_called()
backend.materialize_module.assert_called_once()


def test_gms_rw_post_load_runs_inside_pool_before_finalize(monkeypatch):
"""Every step that may allocate or rebind tensors must run inside the GMS pool.

Expand Down
Loading
Loading