diff --git a/tensorrt_llm/_torch/memory/gpu_memory_backend.py b/tensorrt_llm/_torch/memory/gpu_memory_backend.py index 6c39d2deccea..12edcd071998 100644 --- a/tensorrt_llm/_torch/memory/gpu_memory_backend.py +++ b/tensorrt_llm/_torch/memory/gpu_memory_backend.py @@ -36,8 +36,16 @@ CUDA memory pool. After loading, weights are committed for read-only access by other workers and the client transitions to RO mode in place. - **RO (Read-Only)**: Subsequent workers zero-copy import already-committed - weights from the GMS pool. `post_load_weights()` must run BEFORE - materialization so that module aliases are set up correctly. + weights from the GMS pool. `setup_aliases()` must run BEFORE + materialization so that module aliases are set up correctly, while derived + state is refreshed after real tensors are bound. RO is validated for models + whose `post_load_weights()` is pure alias wiring; models that additionally + rely on plain Python attributes set inside `post_load_weights()` (rather + than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate + those side effects to `cache_derived_state()` or another path that runs on + RO readers. The GMS RO reader runs `setup_aliases()` before + `materialize_module()` and `cache_derived_state()` afterward; it does not + run `transform_weights()`. """ from contextlib import contextmanager @@ -477,7 +485,7 @@ def materialize_module(self, model: nn.Module) -> None: by GPU pointers from the shared memory region — no data copies, no disk I/O, just CUDA VMM remapping. The model's submodule layout must already match the writer's at commit time, including - any aliases / derived buffers introduced by `post_load_weights`. + any aliases introduced by `setup_aliases`. Args: model: The `nn.Module` to materialize. Walks the full @@ -489,11 +497,10 @@ def materialize_module(self, model: nn.Module) -> None: RuntimeError: If `connect()` has not been called yet. Note: - `post_load_weights()` must be called on the model BEFORE - this method. The order ensures that any aliases / derived - parameters created by post-load hooks are present on the - module tree at materialization time, so they are bound to - the same GMS storage as their primary tensor. + `setup_aliases()` must be called on the model BEFORE this method. + The order ensures that any structural aliases created by post-load + hooks are present on the module tree at materialization time, so + they are bound to the same GMS storage as their primary tensor. """ if self._client is None: raise RuntimeError("GMS client not connected. Call connect() first.") diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 6e2b4b532f49..9cf5f716dda1 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1921,7 +1921,7 @@ def load_weights(self, weights: ConsumableWeightsDict): weight_loader = DeepseekV3WeightLoader(self) weight_loader.load_weights(weights) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/models/modeling_exaone_moe.py b/tensorrt_llm/_torch/models/modeling_exaone_moe.py index ba8577da9613..40ae3653d6e0 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone_moe.py +++ b/tensorrt_llm/_torch/models/modeling_exaone_moe.py @@ -725,7 +725,7 @@ def load_weights( allow_partial_loading=allow_partial_loading, ) - def post_load_weights(self): + def setup_aliases(self) -> None: # For the cross-layer residual+LN fusion. for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py index 293510b65099..2572ea548e48 100644 --- a/tensorrt_llm/_torch/models/modeling_glm.py +++ b/tensorrt_llm/_torch/models/modeling_glm.py @@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo weight_loader = Glm4WeightLoader(self) weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: layer.next_layer_layernorm = self.model.norm diff --git a/tensorrt_llm/_torch/models/modeling_gpt_oss.py b/tensorrt_llm/_torch/models/modeling_gpt_oss.py index bc908a2ec014..00b5c77c8951 100644 --- a/tensorrt_llm/_torch/models/modeling_gpt_oss.py +++ b/tensorrt_llm/_torch/models/modeling_gpt_oss.py @@ -631,7 +631,7 @@ def load_weights(self, weights: Dict): else: self.load_hf_weights(weights) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.block[:self.config.num_hidden_layers]): if idx == 0: diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index d77c22322de9..1150c806de75 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -484,7 +484,7 @@ def __init__( self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) - # When post_load_weights() chains layernorms across layers, + # When setup_aliases() chains layernorms across layers, # this flag is set to True to skip the input layernorm in # forward() since it is handled by the previous layer. self.skip_input_layernorm = False @@ -709,7 +709,7 @@ def __init__( quantize_type="nvfp4" if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4 and not (differ_pp_stage_with_previous_layer) else None) - # When post_load_weights() chains layernorms across layers, + # When setup_aliases() chains layernorms across layers, # this flag is set to True to skip the input layernorm in # forward() since it is handled by the previous layer. self.skip_input_layernorm = False @@ -983,7 +983,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): self.norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) - # When post_load_weights() chains the final norm into the + # When setup_aliases() chains the final norm into the # last decoder layer, this flag is set to True to skip # applying it again in forward(). self.skip_norm = False @@ -1088,7 +1088,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): self.norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) - # When post_load_weights() chains the final norm into the + # When setup_aliases() chains the final norm into the # last decoder layer, this flag is set to True to skip # applying it again in forward(). self.skip_norm = False @@ -1140,7 +1140,7 @@ def __init__( ): super().__init__(LlamaModel(model_config), model_config) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: @@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper): if had_mm_encoder: self.mm_encoder = saved_mm_encoder - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index 43b4499f4874..571e3fe503c0 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -417,7 +417,7 @@ def __init__( ) self.preload_weight_modules = self.model.preload_weight_modules - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_next.py b/tensorrt_llm/_torch/models/modeling_qwen3_next.py index 345f8eacee6e..7667972804ad 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_next.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_next.py @@ -980,7 +980,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper): new_weights = weight_mapper.preprocess_weights(weights) super().load_weights(new_weights, weight_mapper) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 6df540b88487..7aed2da43cdf 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -485,8 +485,9 @@ def init_meta_tensor(t: torch.Tensor): # post_load_* hooks itself, so the shared post-load block below # must skip them. RW handles them inside `mem_pool_scope` so the # committed pool reflects the post-post_load layout; RO runs - # `module.post_load_weights()` before `materialize_module` to - # wire aliases prior to zero-copy mapping. + # `setup_aliases()` before `materialize_module` to wire aliases + # prior to zero-copy mapping, then refreshes derived state after + # real GMS tensors are bound. gms_post_load_handled = False if load_format == LoadFormat.AUTO: # Pass model= so format-specific loaders (e.g. MX) can @@ -717,22 +718,23 @@ def init_meta_tensor_in_pool(t: torch.Tensor): # Hook order: # 1. `post_load_apply`: format-specific apply # work (e.g., MX preshard markers). - # 2. Per-module `post_load_weights`: creates - # aliases/derived parameter attributes BEFORE - # `materialize_module` walks the final module - # tree (including `draft_model` for spec dec). - # 3. `materialize_module`: zero-copy bind GMS + # 2. Per-module `setup_aliases`: creates structural + # aliases BEFORE `materialize_module` walks the + # final module tree (including `draft_model` for + # spec dec). + # 3. SourceIdentity gate: STRICT pre-materialize + # compatibility check (GMS has no disk fallback). + # 4. `materialize_module`: zero-copy bind GMS # pool storage onto the model parameters. - # 4. `post_load_publish`: any receiver-side + # 5. Per-module `cache_derived_state`: recompute + # Python-side state from real, materialized + # tensors without re-running one-shot transforms. + # 6. `post_load_publish`: any receiver-side # publish (no-op via the receiver guard). checkpoint_loader.post_load_apply( model, weights_preloaded=True) - for module in model.modules(): - if hasattr(module, - 'post_load_weights') and not getattr( - module, '_weights_removed', False): - module.post_load_weights() + self._setup_aliases(model) # Pre-materialize compatibility gate. GMS has no # disk-fallback path, so a mismatch raises under STRICT @@ -740,6 +742,7 @@ def init_meta_tensor_in_pool(t: torch.Tensor): self._check_gms_source_identity(gms_backend) gms_backend.materialize_module(model) + self._walk_cache_state(model) checkpoint_loader.post_load_publish( model, @@ -829,22 +832,24 @@ def _check_gms_source_identity(self, gms_backend) -> None: @staticmethod def _setup_aliases(model: DecoderModelForCausalLM) -> None: - """Run top-level structural alias setup if the model defines it. + """Run structural alias setup on eligible modules. - Alias wiring is a model-level concern. It is intentionally not a - recursive module walk, because migrated aliases are expected to be set - by the root model that owns the layer graph. + The walk is duck-typed so modules can opt in without inheriting a + shared base class. Modules whose weights were removed are skipped, + matching the legacy full post-load walk. Args: - model: Root decoder model whose top-level alias hook should run. + model: Root decoder model whose module tree should be visited. Returns: None. """ - setup_aliases: Optional[Callable[[], None]] = getattr( - model, 'setup_aliases', None) - if setup_aliases is not None: - setup_aliases() + for module in model.modules(): + setup_aliases: Optional[Callable[[], None]] = getattr( + module, 'setup_aliases', None) + if setup_aliases is not None and not getattr( + module, '_weights_removed', False): + setup_aliases() @staticmethod def _walk_transform(model: DecoderModelForCausalLM) -> None: @@ -935,8 +940,11 @@ def reload(self, """Reload model weights without running post-load hooks. Reload is used by incremental update paths that may provide only a - partial set of replacement weights. The owner of the update lifecycle is - responsible for running post-load processing once all bytes are present. + partial set of replacement weights. Full reloads reset transform guards + before rebinding fresh weights. Partial reloads keep existing transform + guards intact because untouched modules may already contain transformed + live weights. The owner of the update lifecycle is responsible for + running post-load processing once all bytes are present. Args: model: Model instance receiving the replacement weights. @@ -952,6 +960,8 @@ def reload(self, "Cannot reload weights: weight_mapper was not initialized. " "This can happen when the initial load used GMS, MX P2P, or " "VISION_ONLY, which bypass the standard weight mapping path.") + if not allow_partial_loading: + self._reset_weights_transformed(model) self._call_load_weights(model.load_weights, weights, self.weight_mapper, diff --git a/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py b/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py index f429f1c5f4d7..3fcc353a9fa2 100644 --- a/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py +++ b/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py @@ -44,7 +44,13 @@ def to(self, *args, **kwargs): def load_weights(self, weights, mapper): self._events.append("load_weights") - def post_load_weights(self): + def setup_aliases(self) -> None: + self._events.append("setup_aliases") + + def cache_derived_state(self) -> None: + self._events.append("cache_derived_state") + + def post_load_weights(self) -> None: self._events.append("post_load_weights") @@ -76,7 +82,9 @@ def _make_loader(monkeypatch, *, events, spec_config=None): loader._call_load_weights = MagicMock( side_effect=lambda fn, weights, mapper, **kwargs: fn(weights, mapper) ) - loader._load_and_validate_config = MagicMock(return_value=SimpleNamespace(name="config")) + loader._load_and_validate_config = MagicMock( + return_value=SimpleNamespace(name="config", mapping=SimpleNamespace()) + ) monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext()) monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context) @@ -147,7 +155,7 @@ def _spec_config_needing_draft_weights(): ), pytest.param( False, - ["post_load_weights", "materialize"], + ["setup_aliases", "materialize", "cache_derived_state"], id="ro", ), ], @@ -161,8 +169,9 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events): (``_apply`` for meta materialization, ``to('cuda')``, weight load, ``post_load_weights``) inside the pool, then commits via ``finalize_write`` once the scope exits. - ro: the reader runs ``post_load_weights`` to wire module aliases - first, then GMS materializes weights via zero-copy mapping. + ro: the reader runs ``setup_aliases`` to wire module aliases, checks + identity compatibility, materializes weights via zero-copy mapping, + then refreshes derived state from real tensors. """ events = [] loader = _make_loader(monkeypatch, events=events) @@ -196,13 +205,55 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events): backend.move_untracked_params.assert_called_once_with(model) backend.finalize_write.assert_called_once_with(model) else: - # RO: post_load_weights() must run before the GMS materialize - # step so module aliases are wired up before zero-copy mapping. + # RO: setup_aliases() must run before the GMS materialize step so + # module aliases are wired up before zero-copy mapping. checkpoint_loader.load_weights.assert_not_called() loader._call_load_weights.assert_not_called() backend.materialize_module.assert_called_once_with(model) +def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch): + events = [] + loader = _make_loader(monkeypatch, events=events) + backend = _build_gms_backend(is_rw=False, events=events) + _install_gms_backend(monkeypatch, backend) + + checkpoint_loader = MagicMock(name="checkpoint_loader") + checkpoint_loader.checkpoint_format = "HF" + + def record(event): + def _append(*_args, **_kwargs): + events.append(event) + + return _append + + checkpoint_loader.post_load_apply.side_effect = record("post_load_apply") + checkpoint_loader.post_load_publish.side_effect = record("post_load_publish") + + # The STRICT pre-materialize identity gate runs between alias setup and + # materialization; record it to pin the ordering without exercising the + # comparison logic, which is covered in test_source_identity.py. + monkeypatch.setattr( + model_loader_mod, + "check_weight_sharing_compatibility", + lambda *_args, **_kwargs: events.append("check_source_identity"), + ) + + loader.load("/ckpt", checkpoint_loader) + + assert events == [ + "post_load_apply", + "setup_aliases", + "check_source_identity", + "materialize", + "cache_derived_state", + "post_load_publish", + ] + assert "post_load_weights" not in events + checkpoint_loader.load_weights.assert_not_called() + backend.materialize_module.assert_called_once() + + def test_gms_rw_post_load_runs_inside_pool_before_finalize(monkeypatch): """Every step that may allocate or rebind tensors must run inside the GMS pool. diff --git a/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py b/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py index 0ffe6fb32a2a..c110cb1a6204 100644 --- a/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py +++ b/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py @@ -132,11 +132,32 @@ def test_mx_success_initializes_mapper_skips_weight_mapping_and_reload_works(mon # reload() uses self.weight_mapper unconditionally; MX success must # initialize it even though the initial load skipped _call_load_weights. + model._weights_transformed = True + model.linear._weights_transformed = True loader.reload(model, {"reloaded": MagicMock()}) assert loader._call_load_weights.call_count == 1 + assert model._weights_transformed is False + assert model.linear._weights_transformed is False assert events == ["post_load_weights", "load_weights"] +def test_reload_partial_loading_preserves_weights_transformed_flags(monkeypatch): + events = [] + loader = _make_loader(monkeypatch, events=events) + loader.weight_mapper = MagicMock(name="weight_mapper") + model = _TinyModel(events) + model._weights_transformed = True + model.linear._weights_transformed = True + + loader.reload(model, {"reloaded": MagicMock()}, allow_partial_loading=True) + + assert loader._call_load_weights.call_count == 1 + assert loader._call_load_weights.call_args.kwargs["allow_partial_loading"] is True + assert model._weights_transformed is True + assert model.linear._weights_transformed is True + assert events == ["load_weights"] + + def test_mx_partial_fallback_merges_returned_weights(monkeypatch): events = [] loader = _make_loader(monkeypatch, events=events) @@ -194,17 +215,17 @@ def __init__( if transformed is not None: self._weights_transformed = transformed - def setup_aliases(self): + def setup_aliases(self) -> None: self.events.append((self.name, "setup_aliases")) - def transform_weights(self): + def transform_weights(self) -> None: self.events.append((self.name, "transform_weights")) self._weights_transformed = True - def cache_derived_state(self): + def cache_derived_state(self) -> None: self.events.append((self.name, "cache_derived_state")) - def post_load_weights(self): + def post_load_weights(self) -> None: self.events.append((self.name, "post_load_weights")) @@ -216,13 +237,17 @@ def __init__(self, events): self.removed_child = _HookRecorder("removed_child", events, removed=True) -def test_staged_hook_setup_aliases_is_top_level_only(): +def test_staged_hook_setup_aliases_walks_skip_removed_modules(): events = [] model = _HookModel(events) ModelLoader._setup_aliases(model) - assert events == [("model", "setup_aliases")] + assert events == [ + ("model", "setup_aliases"), + ("child", "setup_aliases"), + ("transformed_child", "setup_aliases"), + ] def test_staged_hook_walks_skip_removed_and_transformed_modules():