diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py index 0ee3cb4b8f03..4f7cc07618de 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Dense Sparse Attention (DSA) backend for TRT-LLM with indexer-based TopK selection.""" import math import threading @@ -1504,13 +1518,16 @@ def __init__(self, # attribute queries do not end up frozen into a captured graph. warmup_heuristic_topk_decode(top_k=self.index_topk) - def post_load_weights(self): + def cache_derived_state(self) -> None: """Fuse wk + weights_proj into single FP32 weight for F.linear GEMM under allow_tf32 (TF32 tensor cores on Ampere+).""" # wk: [head_dim, hidden_size] + weights_proj: [n_heads, hidden_size] # → fused: [head_dim + n_heads, hidden_size] self._fused_wk_wp_weight = torch.cat( [self.wk.weight.data, self.weights_proj.weight.data], dim=0) + def post_load_weights(self) -> None: + self.cache_derived_state() + @staticmethod def prepare_one_prefill_chunk( metadata: DSAtrtllmAttentionMetadata, @@ -2404,7 +2421,7 @@ def pre_indexer_proj( split in MLA.forward_dsa_proj sees a stable signature. 
""" assert self._fused_wk_wp_weight is not None, \ - "post_load_weights() must be called before forward()" + "cache_derived_state() must be called before forward()" hidden_float = _to_float(hidden_states) with _tf32_matmul_enabled(): # F.linear computes input @ weight.T internally; no explicit .t() needed. diff --git a/tensorrt_llm/_torch/memory/gpu_memory_backend.py b/tensorrt_llm/_torch/memory/gpu_memory_backend.py index 6c39d2deccea..153f0f9c5c79 100644 --- a/tensorrt_llm/_torch/memory/gpu_memory_backend.py +++ b/tensorrt_llm/_torch/memory/gpu_memory_backend.py @@ -36,8 +36,17 @@ CUDA memory pool. After loading, weights are committed for read-only access by other workers and the client transitions to RO mode in place. - **RO (Read-Only)**: Subsequent workers zero-copy import already-committed - weights from the GMS pool. `post_load_weights()` must run BEFORE - materialization so that module aliases are set up correctly. + weights from the GMS pool. `setup_aliases()` must run BEFORE + materialization so that module aliases are set up correctly, while derived + state is refreshed after real tensors are bound. RO is validated for models + whose `post_load_weights()` is pure alias wiring; models that additionally + rely on plain Python attributes set inside `post_load_weights()` (rather + than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate + those side effects to `cache_derived_state()` or another hook that runs on + RO readers. One-shot tensor layout changes belong in `transform_weights()` + on the writer; the GMS RO reader runs `setup_aliases()` before + `materialize_module()`, then `cache_derived_state()` afterward. It does not + run `transform_weights()`. """ from contextlib import contextmanager @@ -477,7 +486,7 @@ def materialize_module(self, model: nn.Module) -> None: by GPU pointers from the shared memory region — no data copies, no disk I/O, just CUDA VMM remapping. 
The model's submodule layout must already match the writer's at commit time, including - any aliases / derived buffers introduced by `post_load_weights`. + any aliases introduced by `setup_aliases`. Args: model: The `nn.Module` to materialize. Walks the full @@ -489,11 +498,10 @@ def materialize_module(self, model: nn.Module) -> None: RuntimeError: If `connect()` has not been called yet. Note: - `post_load_weights()` must be called on the model BEFORE - this method. The order ensures that any aliases / derived - parameters created by post-load hooks are present on the - module tree at materialization time, so they are bound to - the same GMS storage as their primary tensor. + `setup_aliases()` must be called on the model BEFORE this method. + The order ensures that any structural aliases created by post-load + hooks are present on the module tree at materialization time, so + they are bound to the same GMS storage as their primary tensor. """ if self._client is None: raise RuntimeError("GMS client not connected. Call connect() first.") diff --git a/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py b/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py index 789f6c565ca3..e1c09a1b6c08 100644 --- a/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/base_checkpoint_loader.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + from abc import ABC, abstractmethod from typing import Any @@ -69,6 +72,17 @@ def is_weights_preloaded(self) -> bool: """Whether the last load wrote weights directly into the model.""" return False + def is_post_transform_weights_preloaded(self) -> bool: + """Whether the last direct preload delivered post-transform weights. 
+ + This is narrower than :meth:`is_weights_preloaded`: a loader may write + bytes directly into the model while those bytes are still the raw + checkpoint layout. Only return ``True`` when the source identity was + verified and the incoming bytes can safely skip module + ``transform_weights()`` hooks. + """ + return False + def post_load_apply(self, model: nn.Module, *, diff --git a/tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py b/tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py index 4a23e519b8a1..86de1acf90ad 100644 --- a/tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py +++ b/tensorrt_llm/_torch/models/checkpoints/mx/checkpoint_loader.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """MX (ModelExpress) checkpoint loader. Thin adapter on top of the upstream `modelexpress` Python client @@ -28,10 +27,13 @@ `HfCheckpointLoader` base class. """ +import inspect +import json import os import threading import traceback from contextlib import contextmanager +from enum import Enum from pathlib import Path from typing import Any, Callable, Optional, Type, Union @@ -60,6 +62,17 @@ # Tracked as MX-4 in §15 (non-blocking source-query API upstream). 
_MX_SOURCE_QUERY_TIMEOUT_DEFAULT_S = "30" _MX_PUBLISH_ENV_LOCK = threading.Lock() +_MX_SOURCE_IDENTITY_METADATA_KEY = "trtllm_source_identity" +_MX_WEIGHT_LAYOUT_METADATA_KEY = "trtllm_weight_layout" +_MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY = "trtllm_transform_protocol_version" +_MX_WEIGHT_LAYOUT_POST_TRANSFORM = "post_transform" +_MX_STAGED_TRANSFORM_PROTOCOL_VERSION = 1 + + +class _MxWeightLayoutStatus(Enum): + PRE_TRANSFORM = "pre_transform" + POST_TRANSFORM_SUPPORTED = "post_transform_supported" + UNSUPPORTED = "unsupported" @contextmanager @@ -86,9 +99,8 @@ class MXCheckpointLoader(HfCheckpointLoader): When an MX server is reachable AND the upstream `modelexpress` library is installed, weights are transferred directly from a source instance via NIXL/RDMA, bypassing disk I/O. The source - publishes its weights *before* `post_load_weights()` runs so - targets receive raw loaded state and can run their own - post-load transforms. + publishes its weights after `post_load_weights()` runs, together with + metadata that lets compatible targets skip one-shot post-load transforms. When the MX server or library is unavailable, this loader transparently falls back to standard HuggingFace checkpoint @@ -129,6 +141,8 @@ def __init__( self._model_name = str(model_name) if model_name is not None else None self._query_timeout_s = query_timeout_s self._p2p_succeeded = False + self._post_transform_weights_preloaded = False + self._source_identity_compatible_for_last_load = False # Receiver's local SourceIdentity, supplied per load_weights() call by # ModelLoader; the authority for the pre-transfer compatibility gate. self._local_source_identity: Optional[SourceIdentity] = None @@ -184,6 +198,18 @@ def is_weights_preloaded(self) -> bool: """ return self._p2p_succeeded + def is_post_transform_weights_preloaded(self) -> bool: + """Whether the last successful MX preload delivered transformed bytes. 
+ + The source identity bit is included here so callers have one + conservative signal: no identity match, no transform skip. + """ + return ( + self._p2p_succeeded + and self._post_transform_weights_preloaded + and self._source_identity_compatible_for_last_load + ) + def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[str, Any]: """Load weights, preferring MX P2P transfer when available. @@ -206,7 +232,10 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[ model = kwargs.pop("model", None) # Popped here so it never leaks into the disk-fallback signature. self._local_source_identity = kwargs.pop("source_identity", None) + allow_post_transform_weights = kwargs.pop("allow_post_transform_weights", False) self._p2p_succeeded = False + self._post_transform_weights_preloaded = False + self._source_identity_compatible_for_last_load = False if self._mx_server_url is None or model is None: return self._fallback_to_disk( @@ -235,9 +264,15 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[ ) return self._fallback_to_disk(checkpoint_dir, mapping, **kwargs) + source_metadata = self._fetch_source_metadata( + checkpoint_dir, MxClient, _build_trtllm_identity + ) # Pre-transfer compatibility gate: on mismatch, skip the transfer # before any RDMA work starts and fall back to disk. 
- if not self._source_identity_compatible(checkpoint_dir, MxClient, _build_trtllm_identity): + self._source_identity_compatible_for_last_load = self._source_metadata_identity_compatible( + source_metadata + ) + if not self._source_identity_compatible_for_last_load: return self._fallback_to_disk( checkpoint_dir, mapping, @@ -245,6 +280,32 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[ **kwargs, ) + layout_status = _metadata_weight_layout_status(source_metadata) + if layout_status is _MxWeightLayoutStatus.UNSUPPORTED: + self._source_identity_compatible_for_last_load = False + return self._fallback_to_disk( + checkpoint_dir, + mapping, + reason=_metadata_unsupported_layout_reason(source_metadata), + **kwargs, + ) + + self._post_transform_weights_preloaded = ( + layout_status is _MxWeightLayoutStatus.POST_TRANSFORM_SUPPORTED + ) + if self._post_transform_weights_preloaded and not allow_post_transform_weights: + self._post_transform_weights_preloaded = False + self._source_identity_compatible_for_last_load = False + return self._fallback_to_disk( + checkpoint_dir, + mapping, + reason=( + "source publishes post-transform weights but this model is " + "not allow-listed for staged MX receiver loading" + ), + **kwargs, + ) + timeout_override = self._resolve_query_timeout_override( checkpoint_dir, MxClient, @@ -263,7 +324,8 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[ # disk loading remains the correctness path. Preserve the full # traceback so unexpected upstream failures are diagnosable. 
logger.warning( - f"MX P2P transfer failed; falling back to disk loading.\n{traceback.format_exc()}" + "MX P2P transfer failed; falling back to disk loading.\n" + f"{traceback.format_exc()}" ) return self._fallback_to_disk(checkpoint_dir, mapping, **kwargs) @@ -271,6 +333,24 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[ fallback_bytes = sum( tensor.numel() * tensor.element_size() for tensor in fallback_weights.values() ) + if self._post_transform_weights_preloaded: + self._post_transform_weights_preloaded = False + self._source_identity_compatible_for_last_load = False + logger.warning( + "MX P2P returned %d fallback weights (%.2f MiB, size mismatch) " + "from a post-transform source at %s. Falling back to a full " + "disk load to avoid mixing transformed P2P tensors with raw " + "fallback tensors before the full post-load transform path.", + len(fallback_weights), + fallback_bytes / (1 << 20), + self._mx_server_url, + ) + return self._fallback_to_disk( + checkpoint_dir, + mapping, + reason="post-transform source returned partial fallback weights", + **kwargs, + ) # Mixed-success case: MX delivered matched tensors into model # params via P2P and returned only size-mismatched tensors for # the standard disk path to apply. Keep the P2P transfer and @@ -285,6 +365,7 @@ def load_weights(self, checkpoint_dir: str, mapping: Mapping, **kwargs) -> dict[ self._mx_server_url, ) self._p2p_succeeded = True + self._post_transform_weights_preloaded = False return fallback_weights self._p2p_succeeded = True @@ -308,7 +389,8 @@ def _resolve_query_timeout_override( return None logger.warning( - f"No MX source is currently registered for {self._resolve_publish_name(checkpoint_dir)}; " + "No MX source is currently registered for " + f"{self._resolve_publish_name(checkpoint_dir)}; " f"using MX_SOURCE_QUERY_TIMEOUT={_MX_SOURCE_QUERY_TIMEOUT_DEFAULT_S} " "for fast disk fallback. 
Set mx_config.server_query_timeout_s or " "MX_SOURCE_QUERY_TIMEOUT for long-running donor-load deployments." @@ -358,8 +440,17 @@ def _source_identity_compatible( and compatible. `False` when either identity is missing or the identities mismatch, so the caller falls back to disk loading. """ - local_identity = self._local_source_identity source_identity = self._fetch_source_identity(checkpoint_dir, MxClient, build_identity) + return self._source_identity_compatible_with_source(source_identity) + + def _source_metadata_identity_compatible(self, metadata: Optional[dict[str, Any]]) -> bool: + source_identity = _source_identity_from_metadata(metadata) + return self._source_identity_compatible_with_source(source_identity) + + def _source_identity_compatible_with_source( + self, source_identity: Optional[SourceIdentity] + ) -> bool: + local_identity = self._local_source_identity decision = check_weight_sharing_compatibility( local_identity, source_identity, @@ -381,11 +472,79 @@ def _fetch_source_identity( The publisher's identity, or `None` when it cannot be fetched yet (the compatibility gate then rejects P2P and falls back). """ - # TODO(SOURCE-IDENTITY/MX-2): read the publisher's identity from the MX - # metadata channel (get_metadata / WorkerMetadata) once upstream - # exposes a field for it. This is the single seam the gate depends on. + metadata = self._fetch_source_metadata(checkpoint_dir, MxClient, build_identity) + return _source_identity_from_metadata(metadata) + + def _source_metadata_is_post_transform( + self, checkpoint_dir: str, mx_client_type: Type[Any], build_identity: Callable[..., Any] + ) -> bool: + """Whether the selected MX source publishes post-transform bytes. + + The publisher advertises the layout and transform protocol version in + source metadata. Missing layout metadata is treated as pre-transform. + ``load_weights`` uses the full layout status helper to fail closed on + explicit unsupported layouts or protocols before P2P. 
+ """ + metadata = self._fetch_source_metadata(checkpoint_dir, mx_client_type, build_identity) + return _metadata_is_post_transform(metadata) + + def _fetch_source_metadata( + self, checkpoint_dir: str, MxClient: Type[Any], build_identity: Callable[..., Any] + ) -> Optional[dict[str, Any]]: + """Fetch TRT-LLM metadata for the selected MX source, if available.""" + client = None + try: + identity = build_identity(model_name=self._resolve_publish_name(checkpoint_dir)) + client = MxClient(server_url=self._mx_server_url) + for method_name in ("get_source_metadata", "get_metadata", "get_worker_metadata"): + method = getattr(client, method_name, None) + if not callable(method): + continue + try: + metadata = method(identity=identity) + except TypeError: + metadata = method(identity) + metadata_dict = _metadata_to_dict(metadata) + if _metadata_has_trtllm_key(metadata_dict): + return metadata_dict + + list_resp = client.list_sources(identity=identity) + instances = _source_instances_from_list_response(list_resp) + metadata_candidates = [] + for instance in instances: + metadata_dict = _source_instance_metadata(instance) + if metadata_dict: + metadata_candidates.append(metadata_dict) + return self._select_source_metadata(metadata_candidates) + except (AttributeError, RuntimeError, TimeoutError, TypeError, ValueError, grpc.RpcError): + logger.warning( + "MX source metadata fetch failed; falling back to disk loading.\n" + f"{traceback.format_exc()}" + ) + return None + finally: + if client is not None and hasattr(client, "close"): + client.close() return None + def _select_source_metadata( + self, metadata_candidates: list[dict[str, Any]] + ) -> Optional[dict[str, Any]]: + """Select metadata that matches the receiver identity when possible.""" + if not metadata_candidates: + return None + for metadata in metadata_candidates: + if self._source_metadata_matches_local_identity(metadata): + return metadata + return metadata_candidates[0] + + def 
_source_metadata_matches_local_identity(self, metadata: dict[str, Any]) -> bool: + local_identity = getattr(self, "_local_source_identity", None) + source_identity = _source_identity_from_metadata(metadata) + if local_identity is None or source_identity is None: + return False + return local_identity.matches(source_identity).matched + def _resolve_publish_name(self, checkpoint_dir: Optional[str]) -> str: return _resolve_mx_model_name(self._model_name, checkpoint_dir) @@ -403,12 +562,14 @@ def publish_as_source( self, model, checkpoint_dir: Optional[str] = None, + *, + source_identity: Optional[SourceIdentity] = None, ) -> None: """Publish this instance's weights so other ranks can pull via P2P. - Called by the integration in `model_loader.py` *before* - `post_load_weights()` so targets receive raw loaded state and - can apply their own post-load transforms. + Called by the integration in `model_loader.py` after + `post_load_weights()` so targets receive the post-transform runtime + layout and, when allow-listed, can skip their own one-shot transforms. Delegates to the upstream `modelexpress.trtllm_live_transfer.publish_model_params` @@ -421,10 +582,18 @@ def publish_as_source( fallback for resolving the `MODEL_NAME` identity when neither `model_name` was passed to the constructor nor `MODEL_NAME` is set in the environment. + source_identity: Source identity built before weight load from + the same lifecycle point on producer and receiver. """ if self._mx_server_url is None: return + if source_identity is None: + logger.warning( + "Skipping MX post-transform publish because SourceIdentity is " + "unavailable; receivers cannot safely verify transformed weights." + ) + return try: from modelexpress.trtllm_live_transfer import ( @@ -444,6 +613,15 @@ def publish_as_source( # (the env-var dance goes away when upstream exports a public identity # builder / publish API). 
resolved_name = self._resolve_publish_name(checkpoint_dir) + metadata = _build_mx_source_metadata(source_identity) + metadata_kwargs = _publish_metadata_kwargs(publish_model_params, metadata) + if metadata_kwargs is None: + logger.warning( + "Skipping MX post-transform publish because " + "publish_model_params does not accept metadata; receivers " + "cannot safely verify transformed weights." + ) + return env_overrides = { "MODEL_EXPRESS_URL": self._mx_server_url, @@ -463,9 +641,9 @@ def publish_as_source( os.environ[key] = value try: - publish_model_params(model) + publish_model_params(model, **metadata_kwargs) logger.info( - "Published weights to MX server at %s as model=%r", + "Published post-transform weights to MX server at %s as model=%r", self._mx_server_url, resolved_name, ) @@ -484,7 +662,12 @@ def publish_as_source( os.environ[key] = prior_value def post_load_publish( - self, model, *, checkpoint_dir: str, weights_preloaded: bool = False + self, + model, + *, + checkpoint_dir: str, + weights_preloaded: bool = False, + source_identity: Optional[SourceIdentity] = None, ) -> None: """Publish locally loaded weights as an MX source when appropriate. @@ -496,13 +679,17 @@ def post_load_publish( weights_preloaded: Whether this worker already received weights through MX P2P. When true, this worker is an MX receiver and should not republish the same weights as a source. + source_identity: Producer identity serialized into MX metadata so + receivers can verify layout compatibility before transfer. Returns: None. 
""" if weights_preloaded: return - self.publish_as_source(model, checkpoint_dir=checkpoint_dir) + self.publish_as_source( + model, checkpoint_dir=checkpoint_dir, source_identity=source_identity + ) # --------------------------------------------------------------------------- @@ -560,3 +747,163 @@ def _normalize_model_identity(s: str) -> str: if ancestor.name.startswith("models--"): return ancestor.name[len("models--") :].replace("--", "/") return name or "unknown" + + +def _build_mx_source_metadata(source_identity: Optional[SourceIdentity]) -> dict[str, str]: + metadata = { + _MX_WEIGHT_LAYOUT_METADATA_KEY: _MX_WEIGHT_LAYOUT_POST_TRANSFORM, + _MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY: str(_MX_STAGED_TRANSFORM_PROTOCOL_VERSION), + } + if source_identity is not None: + metadata[_MX_SOURCE_IDENTITY_METADATA_KEY] = json.dumps( + source_identity.to_dict(), sort_keys=True + ) + return metadata + + +def _publish_metadata_kwargs( + publish_model_params: Callable[..., Any], + metadata: dict[str, str], +) -> Optional[dict[str, dict[str, str]]]: + try: + signature = inspect.signature(publish_model_params) + except (TypeError, ValueError): + return None + + parameters = signature.parameters + if "metadata" in parameters: + return {"metadata": metadata} + if "worker_metadata" in parameters: + return {"worker_metadata": metadata} + if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters.values()): + return {"metadata": metadata} + return None + + +def _metadata_to_dict(metadata: Any) -> dict[str, Any]: + if metadata is None or isinstance(metadata, (str, bytes)): + return {} + if isinstance(metadata, dict): + return dict(metadata) + + items = getattr(metadata, "items", None) + if callable(items): + try: + return dict(items()) + except (TypeError, ValueError): + pass + + if type(metadata).__module__.startswith("unittest.mock"): + return {} + + attrs = getattr(metadata, "__dict__", None) + if isinstance(attrs, dict): + return dict(attrs) + return {} + + +def 
_metadata_get(metadata: Optional[dict[str, Any]], key: str) -> Any: + if not metadata: + return None + return metadata.get(key) + + +def _metadata_has_trtllm_key(metadata: dict[str, Any]) -> bool: + return any( + key in metadata + for key in ( + _MX_SOURCE_IDENTITY_METADATA_KEY, + _MX_WEIGHT_LAYOUT_METADATA_KEY, + _MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY, + ) + ) + + +def _source_identity_from_metadata(metadata: Optional[dict[str, Any]]) -> Optional[SourceIdentity]: + value = _metadata_get(metadata, _MX_SOURCE_IDENTITY_METADATA_KEY) + if value is None: + return None + + try: + if isinstance(value, SourceIdentity): + return value + if isinstance(value, bytes): + value = value.decode("utf-8") + if isinstance(value, str): + value = json.loads(value) + to_dict = getattr(value, "to_dict", None) + if callable(to_dict): + value = to_dict() + if not isinstance(value, dict): + raise TypeError(f"expected dict-compatible SourceIdentity, got {type(value)!r}") + return SourceIdentity.from_dict(value) + except (json.JSONDecodeError, KeyError, TypeError, ValueError): + logger.warning( + "MX source metadata contains an invalid SourceIdentity; falling back to disk loading." 
+ ) + return None + + +def _metadata_is_post_transform(metadata: Optional[dict[str, Any]]) -> bool: + return ( + _metadata_weight_layout_status(metadata) is _MxWeightLayoutStatus.POST_TRANSFORM_SUPPORTED + ) + + +def _metadata_weight_layout_status(metadata: Optional[dict[str, Any]]) -> _MxWeightLayoutStatus: + layout = _metadata_get(metadata, _MX_WEIGHT_LAYOUT_METADATA_KEY) + if layout is None: + return _MxWeightLayoutStatus.PRE_TRANSFORM + + normalized_layout = str(layout).lower() + if normalized_layout == "pre_transform": + return _MxWeightLayoutStatus.PRE_TRANSFORM + if normalized_layout != _MX_WEIGHT_LAYOUT_POST_TRANSFORM: + return _MxWeightLayoutStatus.UNSUPPORTED + + version = _metadata_get(metadata, _MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY) + try: + protocol_version = int(version) + except (TypeError, ValueError): + return _MxWeightLayoutStatus.UNSUPPORTED + if protocol_version != _MX_STAGED_TRANSFORM_PROTOCOL_VERSION: + return _MxWeightLayoutStatus.UNSUPPORTED + return _MxWeightLayoutStatus.POST_TRANSFORM_SUPPORTED + + +def _metadata_unsupported_layout_reason(metadata: Optional[dict[str, Any]]) -> str: + layout = _metadata_get(metadata, _MX_WEIGHT_LAYOUT_METADATA_KEY) + if str(layout).lower() == _MX_WEIGHT_LAYOUT_POST_TRANSFORM: + version = _metadata_get(metadata, _MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY) + return ( + "source publishes post-transform weights with unsupported " + f"transform protocol {version!r}" + ) + return f"source publishes unsupported MX weight layout {layout!r}" + + +def _source_instances_from_list_response(list_resp: Any) -> list[Any]: + if isinstance(list_resp, dict): + instances = list_resp.get("instances", []) + else: + instances = getattr(list_resp, "instances", []) + return list(instances or []) + + +def _source_instance_metadata(instance: Any) -> dict[str, Any]: + for candidate in ( + instance, + _metadata_attr(instance, "metadata"), + _metadata_attr(instance, "worker_metadata"), + _metadata_attr(instance, 
"source_metadata"), + ): + metadata = _metadata_to_dict(candidate) + if _metadata_has_trtllm_key(metadata): + return metadata + return {} + + +def _metadata_attr(instance: Any, name: str) -> Any: + if isinstance(instance, dict): + return instance.get(name) + return getattr(instance, name, None) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 6e2b4b532f49..9cf5f716dda1 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1921,7 +1921,7 @@ def load_weights(self, weights: ConsumableWeightsDict): weight_loader = DeepseekV3WeightLoader(self) weight_loader.load_weights(weights) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/models/modeling_exaone_moe.py b/tensorrt_llm/_torch/models/modeling_exaone_moe.py index ba8577da9613..40ae3653d6e0 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone_moe.py +++ b/tensorrt_llm/_torch/models/modeling_exaone_moe.py @@ -725,7 +725,7 @@ def load_weights( allow_partial_loading=allow_partial_loading, ) - def post_load_weights(self): + def setup_aliases(self) -> None: # For the cross-layer residual+LN fusion. 
for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/models/modeling_glm.py b/tensorrt_llm/_torch/models/modeling_glm.py index 293510b65099..2572ea548e48 100644 --- a/tensorrt_llm/_torch/models/modeling_glm.py +++ b/tensorrt_llm/_torch/models/modeling_glm.py @@ -1074,7 +1074,7 @@ def load_weights(self, weights: ConsumableWeightsDict, allow_partial_loading: bo weight_loader = Glm4WeightLoader(self) weight_loader.load_weights(weights, allow_partial_loading=allow_partial_loading) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate(self.model.layers[: self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: layer.next_layer_layernorm = self.model.norm diff --git a/tensorrt_llm/_torch/models/modeling_gpt_oss.py b/tensorrt_llm/_torch/models/modeling_gpt_oss.py index bc908a2ec014..00b5c77c8951 100644 --- a/tensorrt_llm/_torch/models/modeling_gpt_oss.py +++ b/tensorrt_llm/_torch/models/modeling_gpt_oss.py @@ -631,7 +631,7 @@ def load_weights(self, weights: Dict): else: self.load_hf_weights(weights) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.block[:self.config.num_hidden_layers]): if idx == 0: diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index d77c22322de9..1150c806de75 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -484,7 +484,7 @@ def __init__( self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) - # When post_load_weights() chains layernorms across layers, + # When setup_aliases() chains layernorms across layers, # this flag is set to True to skip the input layernorm in # forward() since it is handled by the previous layer. 
self.skip_input_layernorm = False @@ -709,7 +709,7 @@ def __init__( quantize_type="nvfp4" if not self.disable_nvfp4_layernorm_fusion and self.is_nvfp4 and not (differ_pp_stage_with_previous_layer) else None) - # When post_load_weights() chains layernorms across layers, + # When setup_aliases() chains layernorms across layers, # this flag is set to True to skip the input layernorm in # forward() since it is handled by the previous layer. self.skip_input_layernorm = False @@ -983,7 +983,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): self.norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) - # When post_load_weights() chains the final norm into the + # When setup_aliases() chains the final norm into the # last decoder layer, this flag is set to True to skip # applying it again in forward(). self.skip_norm = False @@ -1088,7 +1088,7 @@ def __init__(self, model_config: ModelConfig[LlamaConfig]): self.norm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, dtype=config.torch_dtype) - # When post_load_weights() chains the final norm into the + # When setup_aliases() chains the final norm into the # last decoder layer, this flag is set to True to skip # applying it again in forward(). 
self.skip_norm = False @@ -1140,7 +1140,7 @@ def __init__( ): super().__init__(LlamaModel(model_config), model_config) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: @@ -1564,7 +1564,7 @@ def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper): if had_mm_encoder: self.mm_encoder = saved_mm_encoder - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/models/modeling_llama_min_latency.py b/tensorrt_llm/_torch/models/modeling_llama_min_latency.py index aafd9ef057e6..c13dbbb41864 100644 --- a/tensorrt_llm/_torch/models/modeling_llama_min_latency.py +++ b/tensorrt_llm/_torch/models/modeling_llama_min_latency.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from collections.abc import Callable from typing import Dict, List, Optional, Tuple, Union @@ -308,7 +323,7 @@ def __init__(self, # After loading both gate_up_proj and down_proj, we need to set the scales needed by the special kernels and by # the trtllm-gen gemm+swiglu kernel. 
- def post_load_weights(self): + def cache_derived_state(self) -> None: if self.gate_up_proj.has_fp8_qdq: # For the special gemm+swiglu kernel, we need to set the inverse of the output scale, which is the inverse # of down_proj's combined input scale. @@ -317,6 +332,9 @@ def post_load_weights(self): # combined input scale times inv_output_scale. self.gate_up_proj.trtllm_gen_global_scale = self.gate_up_proj.combined_scale * self.gate_up_proj.inv_output_scale + def post_load_weights(self) -> None: + self.cache_derived_state() + def forward( self, x: Union[torch.Tensor, Fp4QuantizedTensor], @@ -566,7 +584,7 @@ def __init__( dtype=model_config.pretrained_config.torch_dtype, quant_config=None) - def post_load_weights(self): + def cache_derived_state(self) -> None: # Set min-latency quant scales for routed experts if we plan to use min-latency MoE kernels. # This is because the routed experts' input scale is after the score multiplication, so we must use the # pre-score scaling input scale, which happens to be shared expert's input scale. 
@@ -582,6 +600,9 @@ def post_load_weights(self): fc1_input_dequant=pre_score_scaling_input_scale, ) + def post_load_weights(self) -> None: + self.cache_derived_state() + def compute_routed_output( self, hidden_states, diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py index 43b4499f4874..571e3fe503c0 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_moe.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_moe.py @@ -417,7 +417,7 @@ def __init__( ) self.preload_weight_modules = self.model.preload_weight_modules - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/models/modeling_qwen3_next.py b/tensorrt_llm/_torch/models/modeling_qwen3_next.py index 345f8eacee6e..7667972804ad 100644 --- a/tensorrt_llm/_torch/models/modeling_qwen3_next.py +++ b/tensorrt_llm/_torch/models/modeling_qwen3_next.py @@ -980,7 +980,7 @@ def load_weights(self, weights: dict, weight_mapper: BaseWeightMapper): new_weights = weight_mapper.preprocess_weights(weights) super().load_weights(new_weights, weight_mapper) - def post_load_weights(self): + def setup_aliases(self) -> None: for idx, layer in enumerate( self.model.layers[:self.config.num_hidden_layers]): if idx == self.config.num_hidden_layers - 1: diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 28d5cff5c3d8..b4790b200155 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -1261,6 +1261,7 @@ def __init__( self.layer_idx = layer_idx self.layer_idx_str = str(layer_idx) self.dtype = dtype + self._weights_transformed = False self.hidden_size = hidden_size self.num_heads = num_attention_heads @@ -1648,6 +1649,7 @@ def create_weights(self): else: self.k_b_proj_trans_scale = None self.v_b_proj_scale = None + 
self._weights_transformed = False def apply_rope( self, @@ -3027,7 +3029,9 @@ def resmooth_parameters(self, return weight_param, scale_param - def post_load_weights(self): + def transform_weights(self) -> None: + if self._weights_transformed: + return has_fp8_block_scales = ( self.kv_b_proj.quant_config and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales()) @@ -3040,3 +3044,10 @@ def post_load_weights(self): self.v_b_proj, self.v_b_proj_scale = self.resmooth_parameters( self.v_b_proj, self.v_b_proj_scale, recipe=(1, 128, 128)) + self._weights_transformed = True + + def cache_derived_state(self) -> None: + self._weights_transformed = True + + def post_load_weights(self) -> None: + self.transform_weights() diff --git a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py index bdef874a71b2..d4b36d750a7c 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ ConfigurableMoE: Composition-based Configurable MoE Module @@ -73,7 +72,6 @@ class ConfigurableMoE(MoE): # authoritative check -- if the chosen inner backend doesn't opt in, its # ``MoE.__init__`` will still raise. 
_supports_non_divisible_ep: bool = True - """ Configurable MoE layer using composition pattern with automatic configuration @@ -650,17 +648,39 @@ def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False) assert hasattr(self.backend, "load_weights"), ( f"Backend {self.backend.__class__.__name__} must implement load_weights()" ) + self._weights_transformed = False return self.backend.load_weights(weights, allow_partial_loading) - def post_load_weights(self): + def transform_weights(self) -> None: + """ + Transform weights - delegated to backend + + """ + if getattr(self, "_weights_transformed", False): + return + assert hasattr(self.backend, "transform_weights"), ( + f"Backend {self.backend.__class__.__name__} must implement transform_weights()" + ) + self.backend.transform_weights() + self._weights_transformed = True + + def cache_derived_state(self) -> None: """ - Post load weights processing - delegated to backend + Cache derived state - delegated to backend """ - assert hasattr(self.backend, "post_load_weights"), ( - f"Backend {self.backend.__class__.__name__} must implement post_load_weights()" + assert hasattr(self.backend, "cache_derived_state"), ( + f"Backend {self.backend.__class__.__name__} must implement cache_derived_state()" ) - return self.backend.post_load_weights() + self.backend.cache_derived_state() + + def post_load_weights(self) -> None: + """ + Backward-compatible staged post-load processing - delegated to backend + + """ + self.transform_weights() + self.cache_derived_state() def process_weights_after_loading(self): """ diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py index 392c076d954f..f9cc59724ec1 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cute_dsl_b12x.py @@ -68,7 +68,7 @@ class CuteDslB12xFusedMoE(CuteDslFusedMoE): 
``_get_quant_method``). The inherited CUTLASS NVFP4 layout is finalised by the base class, and the b12x-shaped tensors (un-normalised FP8 SF, ``convert_sf_to_mma_layout`` reshape, ``B12xMoEWrapper`` instance) are - materialised on top by the quant method's ``post_load_weights``. Both + materialised on top by the quant method's ``transform_weights``. Both layouts coexist in memory and the dispatcher picks per call based on ``x.shape[0]``. @@ -173,7 +173,7 @@ def _route_to_cutlass(self, x) -> bool: return isinstance(x, torch.Tensor) and x.shape[0] >= self._PREFILL_VIA_CUTLASS_THRESHOLD # ``post_load_weights`` is inherited from ``CutlassFusedMoE`` and - # dispatches to ``self.quant_method.post_load_weights(self)`` — for this + # dispatches to ``self.quant_method.transform_weights(self)`` — for this # backend ``self.quant_method`` is ``NVFP4CuteDslB12xFusedMoEMethod`` # (see ``_get_quant_method`` override), which performs the SF un-normalization, # ``convert_sf_to_mma_layout`` reshape, ``B12xMoEWrapper`` instantiation, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py old mode 100755 new mode 100644 index 85fdac23361d..b9365f604c0d --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + import inspect import os from functools import cached_property @@ -1576,5 +1591,6 @@ def load_weights(self, self.quant_method.load_weights(self, weights, self.weight_loading_mode, **kargs) - def post_load_weights(self): - self.quant_method.post_load_weights(self) + def post_load_weights(self) -> None: + self.transform_weights() + self.cache_derived_state() diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py index f05f20215adf..7cb54cf1300b 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_densegemm.py @@ -290,8 +290,9 @@ def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False) f"got {self.quant_config.quant_mode}." ) - def post_load_weights(self): - self.quant_method.post_load_weights(self) + def post_load_weights(self) -> None: + self.transform_weights() + self.cache_derived_state() def _transform_w2_weight_scale_for_min_latency(self): """Transform w2_weight_scale for minimum latency path optimization.""" diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py index 4c46490b619d..b01a4bf16245 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_triton.py @@ -1368,7 +1368,7 @@ def _maybe_remove_padding(gemm_output, expected_size): return gemm2_output - def post_load_weights(self, module: torch.nn.Module): + def transform_weights(self, module: torch.nn.Module) -> None: if 'w3_w1_weight' in module._parameters: w31_scale = shuffle_weight_for_activation_kernel( module.fc31_dequant.data) @@ -1382,7 +1382,7 @@ def post_load_weights(self, module: torch.nn.Module): module.fc31_input_dequant = None module.fc2_input_dequant = None - 
super().post_load_weights(module) + super().transform_weights(module) class TritonFusedMoE(MoE): @@ -1585,5 +1585,6 @@ def load_weights(self, self.quant_method.load_weights(self, weights, self.weight_loading_mode) - def post_load_weights(self): - self.quant_method.post_load_weights(self) + def post_load_weights(self) -> None: + self.transform_weights() + self.cache_derived_state() diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index e9dac388d34b..89a01ba46a5c 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -524,8 +524,9 @@ def load_weights(self, self.quant_method.load_weights(self, weights, self.weight_loading_mode, **kargs) - def post_load_weights(self): - self.quant_method.post_load_weights(self) + def post_load_weights(self) -> None: + self.transform_weights() + self.cache_derived_state() def quantize_input(self, x, post_quant_comm: bool = True): """Quantize inputs prior to post-communication (alltoall/allgather) or before MoE computation. diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py old mode 100755 new mode 100644 index 1edbf54c2914..05aa7d91a80b --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -1,3 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect import os from typing import Dict, List, Optional, Tuple, Union @@ -949,8 +964,9 @@ def load_weights(self, self.quant_method.load_weights(self, weights, self.weight_loading_mode, **kargs) - def post_load_weights(self): - self.quant_method.post_load_weights(self) + def post_load_weights(self) -> None: + self.transform_weights() + self.cache_derived_state() def forward_fake( self, diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py index c57386fae2d9..e00deec91fca 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/interface.py +++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py @@ -827,8 +827,18 @@ def load_weights(self, """ raise NotImplementedError - def post_load_weights(self): - pass + def transform_weights(self) -> None: + if getattr(self, "_weights_transformed", False): + return + self.quant_method.transform_weights(self) + self._weights_transformed = True + + def cache_derived_state(self) -> None: + self.quant_method.cache_derived_state(self) + + def post_load_weights(self) -> None: + self.transform_weights() + self.cache_derived_state() def process_weights_after_loading(self): """ diff --git a/tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_deepgemm.py index a31b22419bd6..3f715165f19c 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/mega_moe/mega_moe_deepgemm.py @@ -518,7 +518,8 @@ def load_weights(self, weights: 
List[Dict], allow_partial_loading: bool = False) def post_load_weights(self) -> None: if self.quant_method is None: self.create_weights() - self.quant_method.post_load_weights(self) + self.transform_weights() + self.cache_derived_state() # ------------------------------------------------------------------ # MoE-contract methods diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index 42c9c84e2408..f18c33296b67 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -506,7 +506,7 @@ def load_weights(self, # before the next layer is loaded. This prevents accumulation of all # layers' shared weight tensors in host memory simultaneously. # Partial-loading callers (e.g. RLHF reload) instead rely on - # ``post_load_weights`` to finalize once the loading sequence ends. + # ``transform_weights`` to finalize once the loading sequence ends. self._finalize_shared_weights(module) def _prepare_shared_weights_for_finalization(self, module: torch.nn.Module): @@ -525,14 +525,14 @@ def _finalize_shared_weights(self, module: torch.nn.Module): shared memory, then delete the private CPU tensor copies. Calling this at the end of each layer's ``load_weights`` (rather than - deferring to ``post_load_weights``) prevents all layers' CPU tensors + deferring to ``transform_weights``) prevents all layers' CPU tensors from accumulating in host memory simultaneously. With fewer GPUs per node each rank is responsible for more experts, so the accumulated private tensors can easily exceed available host memory. This method is idempotent: if the per-layer ``local_shared_*`` tensors have already been finalized (and deleted), it is a no-op. This lets - ``post_load_weights`` invoke it as a safety net for callers that pass + ``transform_weights`` invoke it as a safety net for callers that pass ``allow_partial_loading=True`` (e.g. 
the RLHF reload path), without double-finalizing in the eager path. """ @@ -559,12 +559,14 @@ def _finalize_shared_weights(self, module: torch.nn.Module): module.register_all_parameter_slot_and_to_fix_weight_fns(weight_fns) module.layer_load_balancer.host_tensor_sharer.finalize_layer_weights() - def post_load_weights(self, module: torch.nn.Module): + def transform_weights(self, module: torch.nn.Module) -> None: # Safety net for deferred-finalization callers (e.g. RLHF reload, which # passes allow_partial_loading=True and so skips the eager per-layer # finalization in load_weights). Idempotent when finalization already # ran eagerly. self._finalize_shared_weights(module) + + def cache_derived_state(self, module: torch.nn.Module) -> None: if hasattr(module, "layer_load_balancer") and module.layer_load_balancer: module.layer_load_balancer.set_initial_weight_assignments( @@ -572,6 +574,10 @@ def post_load_weights(self, module: torch.nn.Module): # Re-setup quant scales after loading weights as the tensors may have been modified. 
self.setup_quant_scales(module) + def post_load_weights(self, module: torch.nn.Module) -> None: + self.transform_weights(module) + self.cache_derived_state(module) + def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]): pass @@ -775,10 +781,10 @@ def process_weights_after_loading(self, module: torch.nn.Module): module.rebuild_tensor_metadata) module._trtllm_gen_layout_transform_pending = False - def post_load_weights(self, module: torch.nn.Module): + def transform_weights(self, module: torch.nn.Module) -> None: if getattr(module, "_trtllm_gen_layout_transform_pending", False): self.process_weights_after_loading(module) - super().post_load_weights(module) + super().transform_weights(module) def load_expert_fc31_input_scale_fp8_qdq(w1_input_scale, w3_input_scale, @@ -1006,8 +1012,8 @@ def process_weights_after_loading(self, module: torch.nn.Module): delattr(module, 'tmp_fc31_input_scale') delattr(module, 'tmp_fc2_input_scale') - def post_load_weights(self, module): - super().post_load_weights(module) + def transform_weights(self, module: torch.nn.Module) -> None: + super().transform_weights(module) # Padding weights to meet FP8 GEMM alignment requirements. 
def _maybe_padding_weights(tensor: torch.Tensor, row_alignment: int, @@ -1271,11 +1277,11 @@ def _prepare_shared_weights_for_finalization(self, module: torch.nn.Module): transformed_shared_w2_scale.cpu()) super()._prepare_shared_weights_for_finalization(module) - def post_load_weights(self, module: torch.nn.Module): - super().post_load_weights(module) + def transform_weights(self, module: torch.nn.Module) -> None: + super().transform_weights(module) if self._needs_e8m0_resmooth(): - logger.debug("Resmoothing FP8 weights in post_load_weights") + logger.debug("Resmoothing FP8 weights in transform_weights") resmoothed_w3_w1_weight, transformed_w3_w1_scale = resmooth_and_transform_fp8_scale( module.w3_w1_weight, module.w3_w1_weight_scaling_factor) module.w3_w1_weight.data.copy_(resmoothed_w3_w1_weight) @@ -1291,7 +1297,6 @@ def post_load_weights(self, module: torch.nn.Module): "w2_weight_scaling_factor", transformed_w2_scale, module.rebuild_tensor_metadata) - self.setup_quant_scales(module) class INT8WoqPerChannelFusedMoEMethod(FusedMoEMethodBase): @@ -3075,11 +3080,11 @@ class NVFP4CuteDslB12xFusedMoEMethod(NVFP4CutlassFusedMoEMethod): """NVFP4 quant method for the FlashInfer B12x MoE backend (SM120 / SM121). Inherits the full CUTLASS NVFP4 weight pipeline (cat + pad + - block_scale_interleave + setup_quant_scales) so the backend's + block_scale_interleave) so the backend's hybrid prefill path can continue to consume the standard CUTLASS NVFP4 GroupGEMM layout via the inherited ``CutlassFusedMoE.run_moe``. 
- On top of that base layout, ``post_load_weights`` materialises the + On top of that base layout, ``transform_weights`` materialises the b12x-specific weight tensors: SF un-normalization (multiply per-block FP8 scales by ``weight_scale_2 = fc_alpha * fc_input_scale``), ``convert_sf_to_mma_layout`` reshape, per-expert ``w*_alpha = 1 / @@ -3098,11 +3103,12 @@ class NVFP4CuteDslB12xFusedMoEMethod(NVFP4CutlassFusedMoEMethod): ActivationType.Swiglu: "silu", } - def post_load_weights(self, module: torch.nn.Module): - # Base class handles shared-weight finalize, load-balancer init, - # and setup_quant_scales. Leaves the standard CUTLASS NVFP4 - # weight + SF layout in place for the inherited prefill path. - super().post_load_weights(module) + def transform_weights(self, module: torch.nn.Module) -> None: + # Base class handles shared-weight finalization. The cache stage + # handles load-balancer assignments and setup_quant_scales. + # Leaves the standard CUTLASS NVFP4 weight + SF layout in place + # for the inherited prefill path. + super().transform_weights(module) try: from flashinfer import B12xMoEWrapper @@ -5342,8 +5348,8 @@ def create_weights(self, module: torch.nn.Module): def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = tuple() - def post_load_weights(self, module: torch.nn.Module): - super().post_load_weights(module) + def cache_derived_state(self, module: torch.nn.Module) -> None: + super().cache_derived_state(module) # Create a proxy weight of unpadded size; dtype does not matter w1_weight = torch.empty([module.intermediate_size, module.hidden_size]) # Calculate alignment @@ -5934,7 +5940,7 @@ def load_weights( # ``initial_local_expert_ids`` which already populate the device # weight tensors above). We allocate matching CPU tensors with the # same per-expert shape/dtype as the device weights and load the - # same MXFP4 byte layout into them. ``post_load_weights`` will later + # same MXFP4 byte layout into them. 
``transform_weights`` will later # transform these into DG-required form and register them with the # host_tensor_sharer so peer ranks can read them during migration. # The CPU staging is required because EPLB's host_tensor_sharer @@ -5972,8 +5978,25 @@ def load_weights( module.local_shared_w2_scale_tensors, ) + self._clear_transformed_weight_cache(module) module._weights_loaded = True + @staticmethod + def _clear_transformed_weight_cache(module: torch.nn.Module) -> None: + """Drop DG-derived tensors so fresh raw weights rebuild the native layout.""" + for attr in ( + "_t_l1", + "_t_l2", + "_t_l1_weight", + "_t_l1_scale", + "_t_l1_scale_slot", + "_t_l2_weight", + "_t_l2_scale", + "_t_l2_scale_slot", + ): + if hasattr(module, attr): + setattr(module, attr, None) + def _transform_weights_for_mega_moe( self, module: torch.nn.Module, @@ -6025,13 +6048,12 @@ def _transform_weights_for_mega_moe( return dg.transform_weights_for_mega_moe((l1_weight, l1_sf), (l2_weight, l2_sf)) - def post_load_weights(self, module: torch.nn.Module) -> None: + def transform_weights(self, module: torch.nn.Module) -> None: """Transform loaded MXFP4 weights into DG-native form. Pipeline (each step is independent and idempotent on its own guard): 1. ``_transform_main_weights`` - DG-form L1/L2 + EPLB-friendly slot views 2. ``_setup_shared_weights_for_eplb`` - host-side shared copies for dynamic EPLB - 3. ``_attach_initial_weight_assignments`` - tell load_balancer the initial layout The NVLink SymmBuffer (forward-time activation workspace, not weight storage) is allocated by ``MegaMoEDeepGemm.__init__`` itself @@ -6039,10 +6061,12 @@ def post_load_weights(self, module: torch.nn.Module) -> None: ``symm_mem.rendezvous`` collective is safe; see ``MegaMoEDeepGemm._alloc_symm_buffer``. 
""" - assert module._weights_loaded, "post_load_weights before load_weights" + assert module._weights_loaded, "transform_weights before load_weights" self._transform_main_weights(module) self._setup_shared_weights_for_eplb(module) - self._attach_initial_weight_assignments(module) + + def cache_derived_state(self, module: torch.nn.Module) -> None: + super().cache_derived_state(module) def _transform_main_weights(self, module: torch.nn.Module) -> None: """Build DG-form ``_t_l1`` / ``_t_l2`` and EPLB-friendly slot views. @@ -6155,11 +6179,3 @@ def refresh_deepgemm_scale_views(): ): delattr(module, attr) module.layer_load_balancer.host_tensor_sharer.finalize_layer_weights() - - @staticmethod - def _attach_initial_weight_assignments(module: torch.nn.Module) -> None: - """Hand the initial expert->slot assignments to the load balancer.""" - if hasattr(module, - "layer_load_balancer") and module.layer_load_balancer: - module.layer_load_balancer.set_initial_weight_assignments( - module.initial_global_assignments) diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index aae3f3a65ab2..984445f689e8 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -380,13 +380,17 @@ def load_weights(self, if not allow_partial_loading: self.process_weights_after_loading(module) - def post_load_weights(self, module: Linear): - pass + def transform_weights(self, module: Linear) -> None: + return None - def load_weight_scales(self, weights: List[Dict], *args, **kwargs): + def post_load_weights(self, module: Linear) -> None: + self.transform_weights(module) + + def load_weight_scales(self, weights: List[Dict], *args, **kwargs) -> None: """ Load quantized weight scales from the checkpoint. 
""" + return None @abstractmethod def load_weights_vanilla(self, @@ -420,7 +424,7 @@ def load_weights_fused_gate_up_linear( """ raise NotImplementedError - def process_weights_after_loading(self, module: Linear): + def process_weights_after_loading(self, module: Linear) -> None: """ Process quantization weights and scales after loading weights. """ @@ -434,23 +438,27 @@ def process_weights_after_loading(self, module: Linear): else: raise ValueError(f'unsupported weight mode: {weight_mode}') - def process_weights_after_loading_vanilla(self, module: Linear): + def process_weights_after_loading_vanilla(self, module: Linear) -> None: """ Process quantization weights and scales after loading weights for vanilla linear layer. """ + return None - def process_weights_after_loading_fused_qkv_linear(self, module: Linear): + def process_weights_after_loading_fused_qkv_linear(self, + module: Linear) -> None: """ Process quantization weights and scales after loading weights for fused QKV linear layer. """ + return None def process_weights_after_loading_fused_gate_up_linear( - self, module: Linear): + self, module: Linear) -> None: """ Process quantization weights and scales after loading weights for fused gate up linear layer. """ + return None - def pre_reload_weights(self, module: Linear): + def pre_reload_weights(self, module: Linear) -> None: """ Pre-reload weights for the linear layer. 
""" @@ -1241,8 +1249,8 @@ def load_weights_fused_gate_up_linear( copy_weight_shard(module.weight_scale, scale, shard_offset, shard_size) - def post_load_weights(self, module: Linear): - super().post_load_weights(module) + def transform_weights(self, module: Linear) -> None: + super().transform_weights(module) if (is_sm_100f() and not (module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm)) or \ get_sm_version() == 120: @@ -1821,9 +1829,9 @@ def process_weights_after_loading_fused_gate_up_linear( torch.ops.trtllm.block_scale_interleave(ws_swapped), requires_grad=False) - def post_load_weights(self, module: Linear): + def transform_weights(self, module: Linear) -> None: """Pad weight and weight_scale tensors to meet torch trtllm NVFP4 GEMM alignment requirements.""" - super().post_load_weights(module) + super().transform_weights(module) row_alignment, col_alignment = 32, 16 row_pad_size = (row_alignment - module.weight.size(0)) % row_alignment col_pad_size = (col_alignment - module.weight.size(1)) % col_alignment @@ -1873,10 +1881,10 @@ class W4A16NVFP4LinearMethod(NVFP4LinearMethod): its fused path is SM>=100-gated upstream. """ - def post_load_weights(self, module: Linear): + def transform_weights(self, module: Linear) -> None: # Skip parent's 32x16 weight padding (apply() accepts [N, K/2] as-is) # and un-swizzle per-block scale once at load. 
- LinearMethodBase.post_load_weights(self, module) + LinearMethodBase.transform_weights(self, module) pad_rows = fp4_utils.pad_up(module.out_features, 128) pad_cols = fp4_utils.pad_up( module.in_features // module.scaling_vector_size, 4) @@ -2914,6 +2922,7 @@ def __init__( dtype=self.dtype) if reduce_output else None self._weights_created = False + self._weights_transformed = False self.reduce_output = reduce_output self.use_custom_cublas_mm = use_custom_cublas_mm self.use_cute_dsl_bf16_gemm = use_cute_dsl_bf16_gemm @@ -2966,6 +2975,7 @@ def create_weights(self): self.dtype) self._weights_created = True + self._weights_transformed = False @property def has_any_quant(self): @@ -3127,6 +3137,7 @@ def load_weights(self, assert allow_partial_loading is False, ( f"{type(self.quant_method).__name__} does not support " "allow_partial_loading") + self._weights_transformed = False self.quant_method.load_weights( self, weights, @@ -3136,8 +3147,17 @@ def load_weights(self, def process_weights_after_loading(self): self.quant_method.process_weights_after_loading(self) - def post_load_weights(self): - self.quant_method.post_load_weights(self) + def transform_weights(self) -> None: + if self._weights_transformed: + return + self.quant_method.transform_weights(self) + self._weights_transformed = True + + def cache_derived_state(self) -> None: + self._weights_transformed = True + + def post_load_weights(self) -> None: + self.transform_weights() def pre_reload_weights(self): assert hasattr( diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 03f9d8068679..a3813f1099c0 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -253,8 +253,8 @@ def __init__( self.aux_steram = torch.cuda.Stream() self.events = [torch.cuda.Event(), torch.cuda.Event()] - def post_load_weights(self): - """Post-process after loading weights.""" + def cache_derived_state(self) 
-> None: + """Recompute state derived from loaded weights.""" if (self.norm.is_nvfp4 and fused_gated_rmsnorm_quant_shape_ok( self.norm.hidden_size, self.norm.group_size) and self.norm.nvfp4_scale is None): @@ -270,6 +270,9 @@ def post_load_weights(self): p=self.head_dim) self._D_expanded = repeat(self.D, "h -> h p", p=self.head_dim) + def post_load_weights(self) -> None: + self.cache_derived_state() + def _try_attach_nvfp4_scale(self): """Attach input_scale from out_proj to norm for fused RMSNorm+Quant.""" diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index 6df540b88487..2ca1389f38f0 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -1,3 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + import copy import inspect import os @@ -24,7 +27,7 @@ from ...llmapi.llm_args import LoadFormat from ..model_config import ModelConfig -from ..models import AutoModelForCausalLM +from ..models import AutoModelForCausalLM, LlamaForCausalLM from ..models.checkpoints.base_checkpoint_loader import BaseCheckpointLoader from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, timing) @@ -264,6 +267,10 @@ class ModelLoader: Handles the loading, configuration, and weight initialization of a PyTorch model. This class isolates model loading logic from the main execution engine. """ + _MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION = 1 + _MX_STAGED_RECEIVER_ALLOWLIST = frozenset({ + (LlamaForCausalLM, _MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION) + }) def __init__(self, llm_args: TorchLlmArgs, @@ -485,8 +492,9 @@ def init_meta_tensor(t: torch.Tensor): # post_load_* hooks itself, so the shared post-load block below # must skip them. 
RW handles them inside `mem_pool_scope` so the # committed pool reflects the post-post_load layout; RO runs - # `module.post_load_weights()` before `materialize_module` to - # wire aliases prior to zero-copy mapping. + # `setup_aliases()` before `materialize_module` to wire aliases + # prior to zero-copy mapping, then refreshes derived state after + # real GMS tensors are bound. gms_post_load_handled = False if load_format == LoadFormat.AUTO: # Pass model= so format-specific loaders (e.g. MX) can @@ -499,6 +507,9 @@ def init_meta_tensor(t: torch.Tensor): # Generic loaders ignore it; MXCheckpointLoader pops it. "source_identity": self._source_identity, } + if checkpoint_loader.checkpoint_format == "MX": + load_weights_kwargs["allow_post_transform_weights"] = ( + self._can_accept_mx_post_transform_weights(model)) if hasattr(model, 'llm_checkpoint_dir'): weights = checkpoint_loader.load_weights( @@ -611,11 +622,19 @@ def init_meta_tensor_in_pool(t: torch.Tensor): # parameter buffers. Keeping the call shape # consistent here avoids forgetting it when MX+GMS # composition lands later. + load_weights_kwargs = { + "mapping": self.mapping, + "model": model, + "source_identity": self._source_identity, + } + if checkpoint_loader.checkpoint_format == "MX": + allow_post_transform_weights = ( + self._can_accept_mx_post_transform_weights( + model)) + load_weights_kwargs[ + "allow_post_transform_weights"] = allow_post_transform_weights weights = checkpoint_loader.load_weights( - weight_source, - mapping=self.mapping, - model=model, - source_identity=self._source_identity) + weight_source, **load_weights_kwargs) # `weights` may be: # - non-empty dict: standard mapping pipeline runs @@ -672,17 +691,17 @@ def init_meta_tensor_in_pool(t: torch.Tensor): # narrow-scope and commit-ordering concerns. 
checkpoint_loader.post_load_apply( model, weights_preloaded=weights_preloaded) - checkpoint_loader.post_load_publish( + + mx_staged_receiver_path = self._should_run_mx_staged_receiver_path( + checkpoint_loader, model, - checkpoint_dir=checkpoint_dir, weights_preloaded=weights_preloaded) - - for module in model.modules(): - if hasattr( - module, - 'post_load_weights') and not getattr( - module, '_weights_removed', False): - module.post_load_weights() + if mx_staged_receiver_path: + self._setup_aliases(model) + self._mark_weights_transformed(model) + self._walk_cache_state(model) + else: + self._walk_full_post_load(model) # Defensive last-mile sweep: catches strays from # C++ ops that bypassed the active torch @@ -697,6 +716,12 @@ def init_meta_tensor_in_pool(t: torch.Tensor): # cached size doesn't show as live in memory accounting. torch.cuda.empty_cache() + self._post_load_publish( + checkpoint_loader, + model, + checkpoint_dir=checkpoint_dir, + weights_preloaded=weights_preloaded) + # Pool closed. Commit the post-post_load layout. gms_backend.finalize_write(model) gms_post_load_handled = True @@ -717,22 +742,23 @@ def init_meta_tensor_in_pool(t: torch.Tensor): # Hook order: # 1. `post_load_apply`: format-specific apply # work (e.g., MX preshard markers). - # 2. Per-module `post_load_weights`: creates - # aliases/derived parameter attributes BEFORE - # `materialize_module` walks the final module - # tree (including `draft_model` for spec dec). - # 3. `materialize_module`: zero-copy bind GMS + # 2. Per-module `setup_aliases`: creates structural + # aliases BEFORE `materialize_module` walks the + # final module tree (including `draft_model` for + # spec dec). + # 3. SourceIdentity gate: STRICT pre-materialize + # compatibility check (GMS has no disk fallback). + # 4. `materialize_module`: zero-copy bind GMS # pool storage onto the model parameters. - # 4. `post_load_publish`: any receiver-side + # 5. 
Per-module `cache_derived_state`: recompute + # Python-side state from real, materialized + # tensors without re-running one-shot transforms. + # 6. `post_load_publish`: any receiver-side # publish (no-op via the receiver guard). checkpoint_loader.post_load_apply( model, weights_preloaded=True) - for module in model.modules(): - if hasattr(module, - 'post_load_weights') and not getattr( - module, '_weights_removed', False): - module.post_load_weights() + self._setup_aliases(model) # Pre-materialize compatibility gate. GMS has no # disk-fallback path, so a mismatch raises under STRICT @@ -740,11 +766,12 @@ def init_meta_tensor_in_pool(t: torch.Tensor): self._check_gms_source_identity(gms_backend) gms_backend.materialize_module(model) + self._walk_cache_state(model) - checkpoint_loader.post_load_publish( - model, - checkpoint_dir=checkpoint_dir, - weights_preloaded=True) + self._post_load_publish(checkpoint_loader, + model, + checkpoint_dir=checkpoint_dir, + weights_preloaded=True) gms_post_load_handled = True logger.info("LoadFormat.GMS (RO): materialized weights") else: @@ -779,12 +806,24 @@ def init_meta_tensor_in_pool(t: torch.Tensor): if not gms_post_load_handled: checkpoint_loader.post_load_apply( model, weights_preloaded=weights_preloaded) - checkpoint_loader.post_load_publish( + mx_staged_receiver_path = self._should_run_mx_staged_receiver_path( + checkpoint_loader, model, - checkpoint_dir=checkpoint_dir, weights_preloaded=weights_preloaded) - - self._walk_full_post_load(model) + if mx_staged_receiver_path: + self._setup_aliases(model) + self._mark_weights_transformed(model) + self._walk_cache_state(model) + self._post_load_publish(checkpoint_loader, + model, + checkpoint_dir=checkpoint_dir, + weights_preloaded=weights_preloaded) + else: + self._walk_full_post_load(model) + self._post_load_publish(checkpoint_loader, + model, + checkpoint_dir=checkpoint_dir, + weights_preloaded=weights_preloaded) # TODO(GMS-MOE-LB): when the (MoE, GMS) combination is enabled, 
# `register_weight_slots_after_to_cuda` and `finalize_model` @@ -827,24 +866,113 @@ def _check_gms_source_identity(self, gms_backend) -> None: IdentityCheckPolicy.STRICT, ) + def _should_run_mx_staged_receiver_path( + self, checkpoint_loader: BaseCheckpointLoader, + model: DecoderModelForCausalLM, *, weights_preloaded: bool) -> bool: + """Whether an MX receiver can skip one-shot weight transforms. + + MXCheckpointLoader only accepts post-transform P2P bytes when this same + allow-list check passes before transfer, so this post-load branch should + never see a non-allow-listed post-transform receiver in normal use. + """ + if checkpoint_loader.checkpoint_format != "MX" or not weights_preloaded: + return False + + method = getattr(type(checkpoint_loader), + 'is_post_transform_weights_preloaded', None) + if method is None or not checkpoint_loader.is_post_transform_weights_preloaded( + ): + return False + + if self._can_accept_mx_post_transform_weights(model): + logger.info( + "MX receiver using staged post-load path for %s " + "(transform protocol v%d).", + type(model).__name__, + self._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION, + ) + return True + + if self._has_separately_loaded_draft_weights(): + raise RuntimeError( + f"MX receiver got post-transform weights for {type(model).__name__}, " + "but staged MX receivers are disabled when speculative decoding " + "loads separate draft-model weights. Refusing to mix " + "post-transform primary weights with raw draft weights.") + + raise RuntimeError( + f"MX receiver got post-transform weights for {type(model).__name__}, " + "but the model is not allow-listed for staged post-load transform " + f"protocol v{self._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION}. 
" + "Refusing to run the full post-load path on already-transformed " + "weights.") + + def _can_accept_mx_post_transform_weights( + self, model: DecoderModelForCausalLM) -> bool: + return (self._is_mx_staged_receiver_allowlisted(model) + and not self._has_separately_loaded_draft_weights()) + + def _has_separately_loaded_draft_weights(self) -> bool: + return (self.spec_config is not None + and self.spec_config.spec_dec_mode.need_load_draft_weights()) + + @classmethod + def _is_mx_staged_receiver_allowlisted( + cls, model: DecoderModelForCausalLM) -> bool: + for model_type, protocol_version in cls._MX_STAGED_RECEIVER_ALLOWLIST: + if (protocol_version + == cls._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION + and isinstance(model, model_type)): + return True + return False + + def _post_load_publish(self, checkpoint_loader: BaseCheckpointLoader, + model: DecoderModelForCausalLM, *, + checkpoint_dir: str, + weights_preloaded: bool) -> None: + kwargs = { + "checkpoint_dir": checkpoint_dir, + "weights_preloaded": weights_preloaded, + } + if checkpoint_loader.checkpoint_format == "MX": + kwargs["source_identity"] = self._source_identity + checkpoint_loader.post_load_publish(model, **kwargs) + + @staticmethod + def _mark_weights_transformed(model: DecoderModelForCausalLM) -> None: + """Mark modules with transform guards as already transformed. + + Post-transform sharing paths skip ``transform_weights()`` because the + incoming bytes already use the final runtime layout. Preserve that + lifecycle state on modules that participate in the transform guard + protocol so a later orchestrator/refactor does not treat them as raw + checkpoint bytes. + """ + for module in model.modules(): + if hasattr(module, '_weights_transformed') and not getattr( + module, '_weights_removed', False): + module._weights_transformed = True + @staticmethod def _setup_aliases(model: DecoderModelForCausalLM) -> None: - """Run top-level structural alias setup if the model defines it. 
+ """Run structural alias setup on eligible modules. - Alias wiring is a model-level concern. It is intentionally not a - recursive module walk, because migrated aliases are expected to be set - by the root model that owns the layer graph. + The walk is duck-typed so modules can opt in without inheriting a + shared base class. Modules whose weights were removed are skipped, + matching the legacy full post-load walk. Args: - model: Root decoder model whose top-level alias hook should run. + model: Root decoder model whose module tree should be visited. Returns: None. """ - setup_aliases: Optional[Callable[[], None]] = getattr( - model, 'setup_aliases', None) - if setup_aliases is not None: - setup_aliases() + for module in model.modules(): + setup_aliases: Optional[Callable[[], None]] = getattr( + module, 'setup_aliases', None) + if setup_aliases is not None and not getattr( + module, '_weights_removed', False): + setup_aliases() @staticmethod def _walk_transform(model: DecoderModelForCausalLM) -> None: @@ -935,8 +1063,11 @@ def reload(self, """Reload model weights without running post-load hooks. Reload is used by incremental update paths that may provide only a - partial set of replacement weights. The owner of the update lifecycle is - responsible for running post-load processing once all bytes are present. + partial set of replacement weights. Full reloads reset transform guards + before rebinding fresh weights. Partial reloads keep existing transform + guards intact because untouched modules may already contain transformed + live weights. The owner of the update lifecycle is responsible for + running post-load processing once all bytes are present. Args: model: Model instance receiving the replacement weights. @@ -952,6 +1083,8 @@ def reload(self, "Cannot reload weights: weight_mapper was not initialized. 
" "This can happen when the initial load used GMS, MX P2P, or " "VISION_ONLY, which bypass the standard weight mapping path.") + if not allow_partial_loading: + self._reset_weights_transformed(model) self._call_load_weights(model.load_weights, weights, self.weight_mapper, diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py index 2a421a0894a3..8e968b5757f7 100644 --- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py +++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py @@ -1,3 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """ Test suite for DeepGEMM indexer kernels and some related utilities. 
@@ -45,6 +59,23 @@ def has_deep_gemm(): return False +def test_indexer_post_load_weights_caches_fused_weight(): + indexer = Indexer.__new__(Indexer) + torch.nn.Module.__init__(indexer) + indexer.wk = torch.nn.Linear(3, 2, bias=False) + indexer.weights_proj = torch.nn.Linear(3, 4, bias=False) + indexer.wk.weight.data.fill_(1.0) + indexer.weights_proj.weight.data.fill_(2.0) + + indexer.post_load_weights() + + assert indexer._fused_wk_wp_weight.shape == (6, 3) + assert torch.equal(indexer._fused_wk_wp_weight[:2], indexer.wk.weight.data) + assert torch.equal(indexer._fused_wk_wp_weight[2:], + indexer.weights_proj.weight.data) + assert not hasattr(indexer, "_weights_transformed") + + def _ceil_to_ue8m0(x: torch.Tensor): """Round tensor values up to the nearest power of two (UE8M0 format).""" return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) diff --git a/tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py b/tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py index c686a1c664f7..c5c8d4408be7 100644 --- a/tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py +++ b/tests/unittest/_torch/models/checkpoints/mx/test_mx_checkpoint_loader.py @@ -14,9 +14,11 @@ about *our* fallback behavior, not about the upstream API. 
""" +import json import os import sys from contextlib import ExitStack +from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest @@ -29,10 +31,40 @@ ) from tensorrt_llm._torch.models.checkpoints.hf.weight_mapper import HfWeightMapper from tensorrt_llm._torch.models.checkpoints.mx.checkpoint_loader import ( + _MX_SOURCE_IDENTITY_METADATA_KEY, + _MX_STAGED_TRANSFORM_PROTOCOL_VERSION, + _MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY, + _MX_WEIGHT_LAYOUT_METADATA_KEY, + _MX_WEIGHT_LAYOUT_POST_TRANSFORM, MXCheckpointLoader, + _build_mx_source_metadata, _normalize_model_identity, _resolve_mx_model_name, ) +from tensorrt_llm._torch.weight_sharing import SourceIdentity + +_MISSING = object() + + +def _identity(rank: int = 0) -> SourceIdentity: + return SourceIdentity( + format_version=1, + model_fingerprint="model", + quant_fingerprint="quant", + backend_fingerprint="backend", + parallel_fingerprint="parallel", + rank=rank, + shard_fingerprint=f"shard-{rank}", + model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + ) + + +def _source_instance(identity: SourceIdentity, *, post_transform: bool = True): + metadata = _build_mx_source_metadata(identity) + if not post_transform: + metadata[_MX_WEIGHT_LAYOUT_METADATA_KEY] = "pre_transform" + return SimpleNamespace(metadata=metadata) + # --------------------------------------------------------------------------- # Construction & static properties @@ -44,11 +76,13 @@ def test_no_args_constructs(self): loader = MXCheckpointLoader() assert loader.mx_server_url is None assert loader.is_weights_preloaded() is False + assert loader.is_post_transform_weights_preloaded() is False def test_mx_server_url_stored(self): loader = MXCheckpointLoader(mx_server_url="http://mx:8001") assert loader.mx_server_url == "http://mx:8001" assert loader.is_weights_preloaded() is False + assert loader.is_post_transform_weights_preloaded() is False def test_query_timeout_stored(self): loader = 
MXCheckpointLoader(mx_server_url="http://mx:8001", query_timeout_s=900) @@ -75,6 +109,17 @@ def test_checkpoint_format_backing_attr(self): def test_is_weights_preloaded_initial(self): loader = MXCheckpointLoader() assert loader.is_weights_preloaded() is False + assert loader.is_post_transform_weights_preloaded() is False + + def test_post_transform_signal_requires_p2p_and_identity_match(self): + loader = MXCheckpointLoader() + loader._p2p_succeeded = True + loader._post_transform_weights_preloaded = True + loader._source_identity_compatible_for_last_load = False + assert loader.is_post_transform_weights_preloaded() is False + + loader._source_identity_compatible_for_last_load = True + assert loader.is_post_transform_weights_preloaded() is True # --------------------------------------------------------------------------- @@ -145,9 +190,14 @@ def _modelexpress_unavailable(stack): @staticmethod def _upstream_raises(stack): - fake_mx = _build_fake_modelexpress(load_weights_side_effect=RuntimeError("boom")) + identity = _identity() + loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + fake_mx = _build_fake_modelexpress( + load_weights_side_effect=RuntimeError("boom"), + source_instances=[_source_instance(identity, post_transform=False)], + ) stack.enter_context(_install_fake_modelexpress(fake_mx)) - return (MXCheckpointLoader(mx_server_url="http://mx:8001"), {"model": MagicMock()}) + return (loader, {"model": MagicMock(), "source_identity": identity}) @pytest.mark.parametrize( "trigger_id, setup", @@ -195,16 +245,26 @@ def test_p2p_full_success_returns_empty_dict(self): # params; ``ModelLoader`` interprets the empty dict + the # ``is_weights_preloaded()`` signal as "skip the standard # weight-mapping pipeline". 
+ identity = _identity() loader = MXCheckpointLoader(mx_server_url="http://mx:8001") - fake_mx = _build_fake_modelexpress(load_weights_return={}) + fake_mx = _build_fake_modelexpress( + load_weights_return={}, + source_instances=[_source_instance(identity, post_transform=False)], + ) mapping = MagicMock(name="mapping") model = MagicMock(name="model") with _install_fake_modelexpress(fake_mx): - result = loader.load_weights("/nonexistent", mapping=mapping, model=model) + result = loader.load_weights( + "/nonexistent", + mapping=mapping, + model=model, + source_identity=identity, + ) assert result == {} assert loader.is_weights_preloaded() is True + assert loader.is_post_transform_weights_preloaded() is False # Verify the integration contract with the upstream library: # 1. Constructed MxLiveWeightLoader with our mx_server_url. @@ -221,20 +281,218 @@ def test_mixed_success_returns_fallback_weights(self): # When MX returns a non-empty fallback dict (size-mismatched # tensors), keep the P2P transfer and let ModelLoader merge these # tensors through the standard disk pipeline. 
+ identity = _identity() loader = MXCheckpointLoader(mx_server_url="http://mx:8001") fallback = {"some.weight": MagicMock()} - fake_mx = _build_fake_modelexpress(load_weights_return=fallback) + fake_mx = _build_fake_modelexpress( + load_weights_return=fallback, + source_instances=[_source_instance(identity, post_transform=False)], + ) with ( _install_fake_modelexpress(fake_mx), patch.object(HfCheckpointLoader, "load_weights") as mock_super_load, ): - result = loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock()) + result = loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + ) assert loader.is_weights_preloaded() is True + assert loader.is_post_transform_weights_preloaded() is False assert result is fallback mock_super_load.assert_not_called() + def test_post_transform_full_success_sets_skip_signal_when_allowlisted(self): + identity = _identity() + loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + fake_mx = _build_fake_modelexpress( + load_weights_return={}, + source_instances=[_source_instance(identity)], + ) + + with _install_fake_modelexpress(fake_mx): + result = loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + allow_post_transform_weights=True, + ) + + assert result == {} + assert loader.is_weights_preloaded() is True + assert loader.is_post_transform_weights_preloaded() is True + + def test_post_transform_source_falls_back_before_p2p_when_not_allowlisted(self): + identity = _identity() + loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + disk_weights = {"disk.weight": MagicMock()} + fake_mx = _build_fake_modelexpress( + load_weights_return={}, + source_instances=[_source_instance(identity)], + ) + + with ( + _install_fake_modelexpress(fake_mx), + patch.object( + HfCheckpointLoader, "load_weights", return_value=disk_weights + ) as mock_super_load, + ): + result = loader.load_weights( + 
"/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + allow_post_transform_weights=False, + ) + + assert result is disk_weights + assert loader.is_weights_preloaded() is False + assert loader.is_post_transform_weights_preloaded() is False + mx_loader = fake_mx.trtllm_live_transfer.MxLiveWeightLoader.return_value + mx_loader.load_weights.assert_not_called() + mock_super_load.assert_called_once() + + @pytest.mark.parametrize( + "protocol_value", + ["2", "not-an-int", _MISSING], + ids=["newer-protocol", "invalid-protocol", "missing-protocol"], + ) + def test_post_transform_source_with_unsupported_protocol_falls_back_before_p2p( + self, protocol_value + ): + identity = _identity() + loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + disk_weights = {"disk.weight": MagicMock()} + source_instance = _source_instance(identity) + if protocol_value is _MISSING: + source_instance.metadata.pop(_MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY) + else: + source_instance.metadata[_MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY] = protocol_value + fake_mx = _build_fake_modelexpress( + load_weights_return={}, + source_instances=[source_instance], + ) + + with ( + _install_fake_modelexpress(fake_mx), + patch.object( + HfCheckpointLoader, "load_weights", return_value=disk_weights + ) as mock_super_load, + ): + result = loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + allow_post_transform_weights=True, + ) + + assert result is disk_weights + assert loader.is_weights_preloaded() is False + assert loader.is_post_transform_weights_preloaded() is False + mx_loader = fake_mx.trtllm_live_transfer.MxLiveWeightLoader.return_value + mx_loader.load_weights.assert_not_called() + mock_super_load.assert_called_once() + + def test_selects_matching_source_metadata_from_multiple_instances(self): + rank0_identity = _identity(rank=0) + rank1_identity = _identity(rank=1) + loader = 
MXCheckpointLoader(mx_server_url="http://mx:8001") + fake_mx = _build_fake_modelexpress( + load_weights_return={}, + source_instances=[ + _source_instance(rank0_identity), + _source_instance(rank1_identity), + ], + ) + + with ( + _install_fake_modelexpress(fake_mx), + patch.object(HfCheckpointLoader, "load_weights") as mock_super_load, + ): + result = loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=rank1_identity, + allow_post_transform_weights=True, + ) + + assert result == {} + assert loader.is_weights_preloaded() is True + assert loader.is_post_transform_weights_preloaded() is True + mock_super_load.assert_not_called() + + def test_post_transform_mixed_success_falls_back_to_full_disk_load(self): + # Wave 5 will let MX advertise post-transform sources. If such a + # source only partially succeeds, merging raw fallback tensors would + # force ModelLoader onto the full post-load path and double-transform + # the P2P subset. Lock the safer behavior now: abandon the partial + # post-transform transfer and return a full disk load instead. 
+ identity = _identity() + loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + fallback = {"some.weight": MagicMock(numel=lambda: 1, element_size=lambda: 4)} + disk_weights = {"disk.weight": MagicMock()} + fake_mx = _build_fake_modelexpress( + load_weights_return=fallback, + source_instances=[_source_instance(identity)], + ) + + with ( + _install_fake_modelexpress(fake_mx), + patch.object( + HfCheckpointLoader, "load_weights", return_value=disk_weights + ) as mock_super_load, + ): + result = loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + allow_post_transform_weights=True, + ) + + assert result is disk_weights + assert loader.is_weights_preloaded() is False + assert loader.is_post_transform_weights_preloaded() is False + mx_loader = fake_mx.trtllm_live_transfer.MxLiveWeightLoader.return_value + mx_loader.load_weights.assert_called_once() + mock_super_load.assert_called_once() + + def test_post_transform_source_can_be_disallowed_before_p2p(self): + # Some receiver shapes, such as target+draft speculative decoding, are + # not ready to mix post-transform target bytes with separately loaded + # raw draft bytes. Let ModelLoader force a disk fallback before MX + # starts RDMA, rather than accepting bytes it cannot safely stage. 
+ loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + loader._source_identity_compatible = MagicMock(return_value=True) + loader._source_metadata_is_post_transform = MagicMock(return_value=True) + disk_weights = {"disk.weight": MagicMock()} + fake_mx = _build_fake_modelexpress(load_weights_return={}) + + with ( + _install_fake_modelexpress(fake_mx), + patch.object( + HfCheckpointLoader, "load_weights", return_value=disk_weights + ) as mock_super_load, + ): + result = loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + allow_post_transform_weights=False, + ) + + assert result is disk_weights + assert loader.is_weights_preloaded() is False + assert loader.is_post_transform_weights_preloaded() is False + fake_mx.trtllm_live_transfer.MxLiveWeightLoader.assert_not_called() + mock_super_load.assert_called_once() + # --------------------------------------------------------------------------- # publish_as_source — env-var dance and graceful no-op @@ -252,23 +510,57 @@ def test_no_mx_server_url_is_noop(self): def test_modelexpress_unavailable_is_noop(self): loader = MXCheckpointLoader(mx_server_url="http://mx:8001") with _block_modelexpress(): - loader.publish_as_source(MagicMock()) # must not raise + loader.publish_as_source(MagicMock(), source_identity=_identity()) # must not raise def test_publish_called_with_model(self): loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + identity = _identity() fake_mx = _build_fake_modelexpress() model = MagicMock(name="model") with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(model) + loader.publish_as_source(model, source_identity=identity) + + fake_mx.trtllm_live_transfer.publish_model_params.assert_called_once() + args, kwargs = fake_mx.trtllm_live_transfer.publish_model_params.call_args + assert args == (model,) + metadata = kwargs["metadata"] + assert metadata[_MX_WEIGHT_LAYOUT_METADATA_KEY] == _MX_WEIGHT_LAYOUT_POST_TRANSFORM + assert 
metadata[_MX_TRANSFORM_PROTOCOL_VERSION_METADATA_KEY] == str( + _MX_STAGED_TRANSFORM_PROTOCOL_VERSION + ) + assert _MX_SOURCE_IDENTITY_METADATA_KEY in metadata + assert metadata[_MX_SOURCE_IDENTITY_METADATA_KEY] == json.dumps( + identity.to_dict(), sort_keys=True + ) + + def test_source_identity_required_for_post_transform_publish(self): + loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + fake_mx = _build_fake_modelexpress() + + with _install_fake_modelexpress(fake_mx): + loader.publish_as_source(MagicMock()) + + fake_mx.trtllm_live_transfer.publish_model_params.assert_not_called() + + def test_publish_skipped_when_metadata_unsupported(self): + loader = MXCheckpointLoader(mx_server_url="http://mx:8001") + calls = [] + + def _publish_without_metadata(model): + calls.append(model) + + fake_mx = _build_fake_modelexpress(publish_model_params=_publish_without_metadata) + with _install_fake_modelexpress(fake_mx): + loader.publish_as_source(MagicMock(), source_identity=_identity()) - fake_mx.trtllm_live_transfer.publish_model_params.assert_called_once_with(model) + assert calls == [] def test_env_var_set_during_publish_then_restored(self): loader = MXCheckpointLoader(mx_server_url="http://mx-instance:9999") captured_env = {} - def _capture(model): + def _capture(model, **_kwargs): captured_env["MODEL_EXPRESS_URL"] = os.environ.get("MODEL_EXPRESS_URL") fake_mx = _build_fake_modelexpress(publish_side_effect=_capture) @@ -276,7 +568,7 @@ def _capture(model): prior = os.environ.pop("MODEL_EXPRESS_URL", None) try: with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(MagicMock()) + loader.publish_as_source(MagicMock(), source_identity=_identity()) finally: if prior is not None: os.environ["MODEL_EXPRESS_URL"] = prior @@ -293,7 +585,7 @@ def test_env_var_restored_to_prior_value(self): os.environ["MODEL_EXPRESS_URL"] = "http://prior-value:1234" try: with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(MagicMock()) + 
loader.publish_as_source(MagicMock(), source_identity=_identity()) assert os.environ["MODEL_EXPRESS_URL"] == "http://prior-value:1234" finally: if prior is None: @@ -308,7 +600,7 @@ def test_publish_exception_swallowed(self): fake_mx = _build_fake_modelexpress(publish_side_effect=RuntimeError("upstream went away")) with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(MagicMock()) # must not raise + loader.publish_as_source(MagicMock(), source_identity=_identity()) # must not raise # --------------------------------------------------------------------------- @@ -350,7 +642,9 @@ def _build_fake_modelexpress( load_weights_return=None, load_weights_side_effect=None, publish_side_effect=None, + publish_model_params=None, source_instances=None, + source_metadata=None, ): """Build a fake modelexpress module tree mimicking the symbols we use.""" fake_pkg = MagicMock(name="modelexpress") @@ -368,11 +662,15 @@ def _build_fake_modelexpress( fake_trtllm_live.MxLiveWeightLoader = MagicMock(return_value=weight_loader_instance) client_instance = MagicMock(name="MxClient instance") client_instance.list_sources.return_value = MagicMock(instances=source_instances or []) + if source_metadata is not None: + client_instance.get_source_metadata.return_value = source_metadata fake_trtllm_live.MxClient = MagicMock(return_value=client_instance) fake_trtllm_live._build_trtllm_identity = MagicMock(return_value=MagicMock()) # publish_model_params(model) - if publish_side_effect is not None: + if publish_model_params is not None: + fake_trtllm_live.publish_model_params = publish_model_params + elif publish_side_effect is not None: fake_trtllm_live.publish_model_params = MagicMock(side_effect=publish_side_effect) else: fake_trtllm_live.publish_model_params = MagicMock() @@ -423,17 +721,31 @@ def _isolated_env(self, monkeypatch): yield def test_no_registered_source_gets_short_default_during_load(self): + identity = _identity() + source_metadata = 
_build_mx_source_metadata(identity) + source_metadata[_MX_WEIGHT_LAYOUT_METADATA_KEY] = "pre_transform" + def _assert_timeout(*args, **kwargs): assert os.environ.get("MX_SOURCE_QUERY_TIMEOUT") == "30" return {} loader = MXCheckpointLoader(mx_server_url="http://mx:8001") - fake_mx = _build_fake_modelexpress(load_weights_side_effect=_assert_timeout) + fake_mx = _build_fake_modelexpress( + load_weights_side_effect=_assert_timeout, + source_metadata=source_metadata, + ) with _install_fake_modelexpress(fake_mx): - loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock()) + loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + ) assert "MX_SOURCE_QUERY_TIMEOUT" not in os.environ def test_existing_source_keeps_upstream_default_when_unset(self): + identity = _identity() + def _assert_no_timeout(*args, **kwargs): assert "MX_SOURCE_QUERY_TIMEOUT" not in os.environ return {} @@ -441,13 +753,19 @@ def _assert_no_timeout(*args, **kwargs): loader = MXCheckpointLoader(mx_server_url="http://mx:8001") fake_mx = _build_fake_modelexpress( load_weights_side_effect=_assert_no_timeout, - source_instances=[MagicMock()], + source_instances=[_source_instance(identity, post_transform=False)], ) with _install_fake_modelexpress(fake_mx): - loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock()) + loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + ) assert "MX_SOURCE_QUERY_TIMEOUT" not in os.environ def test_env_value_preserved(self, monkeypatch): + identity = _identity() # If the user/orchestrator already set a value, our defensive # default must not stomp it. 
monkeypatch.setenv("MX_SOURCE_QUERY_TIMEOUT", "120") @@ -457,20 +775,38 @@ def _assert_env_timeout(*args, **kwargs): return {} loader = MXCheckpointLoader(mx_server_url="http://mx:8001") - fake_mx = _build_fake_modelexpress(load_weights_side_effect=_assert_env_timeout) + fake_mx = _build_fake_modelexpress( + load_weights_side_effect=_assert_env_timeout, + source_instances=[_source_instance(identity, post_transform=False)], + ) with _install_fake_modelexpress(fake_mx): - loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock()) + loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + ) assert os.environ.get("MX_SOURCE_QUERY_TIMEOUT") == "120" def test_configured_timeout_applies_during_load_and_restores_env(self): + identity = _identity() + def _assert_config_timeout(*args, **kwargs): assert os.environ.get("MX_SOURCE_QUERY_TIMEOUT") == "900" return {} loader = MXCheckpointLoader(mx_server_url="http://mx:8001", query_timeout_s=900) - fake_mx = _build_fake_modelexpress(load_weights_side_effect=_assert_config_timeout) + fake_mx = _build_fake_modelexpress( + load_weights_side_effect=_assert_config_timeout, + source_instances=[_source_instance(identity, post_transform=False)], + ) with _install_fake_modelexpress(fake_mx): - loader.load_weights("/nonexistent", mapping=MagicMock(), model=MagicMock()) + loader.load_weights( + "/nonexistent", + mapping=MagicMock(), + model=MagicMock(), + source_identity=identity, + ) assert "MX_SOURCE_QUERY_TIMEOUT" not in os.environ def test_no_mx_url_does_not_touch_env(self): @@ -605,13 +941,13 @@ def test_uses_explicit_constructor_model_name(self): ) captured = {} - def _capture(model): + def _capture(model, **_kwargs): captured["MODEL_NAME"] = os.environ.get("MODEL_NAME") captured["MODEL_EXPRESS_URL"] = os.environ.get("MODEL_EXPRESS_URL") fake_mx = _build_fake_modelexpress(publish_side_effect=_capture) with _install_fake_modelexpress(fake_mx): - 
loader.publish_as_source(MagicMock()) + loader.publish_as_source(MagicMock(), source_identity=_identity()) assert captured["MODEL_NAME"] == "Qwen/Qwen2.5-72B-Instruct" assert captured["MODEL_EXPRESS_URL"] == "http://mx:8001" @@ -624,12 +960,16 @@ def test_falls_back_to_checkpoint_dir_basename(self): # No constructor model_name, no MODEL_NAME env → use basename. captured = {} - def _capture(model): + def _capture(model, **_kwargs): captured["MODEL_NAME"] = os.environ.get("MODEL_NAME") fake_mx = _build_fake_modelexpress(publish_side_effect=_capture) with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(MagicMock(), checkpoint_dir="/scratch/local-model") + loader.publish_as_source( + MagicMock(), + checkpoint_dir="/scratch/local-model", + source_identity=_identity(), + ) assert captured["MODEL_NAME"] == "local-model" @@ -640,12 +980,16 @@ def test_unmangles_hf_snapshot_path(self): ) captured = {} - def _capture(model): + def _capture(model, **_kwargs): captured["MODEL_NAME"] = os.environ.get("MODEL_NAME") fake_mx = _build_fake_modelexpress(publish_side_effect=_capture) with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(MagicMock(), checkpoint_dir=snapshot) + loader.publish_as_source( + MagicMock(), + checkpoint_dir=snapshot, + source_identity=_identity(), + ) # Critical: NOT the commit hash, the human-readable Hub-ID form. assert captured["MODEL_NAME"] == "Qwen/Qwen2.5-72B-Instruct" @@ -660,12 +1004,12 @@ def test_constructor_model_name_takes_priority_over_env(self, monkeypatch): ) captured = {} - def _capture(model): + def _capture(model, **_kwargs): captured["MODEL_NAME"] = os.environ.get("MODEL_NAME") fake_mx = _build_fake_modelexpress(publish_side_effect=_capture) with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(MagicMock()) + loader.publish_as_source(MagicMock(), source_identity=_identity()) assert captured["MODEL_NAME"] == "explicit" # Restored to the prior env value, not unset. 
@@ -676,12 +1020,12 @@ def test_env_used_when_no_constructor_value(self, monkeypatch): loader = MXCheckpointLoader(mx_server_url="http://mx:8001") captured = {} - def _capture(model): + def _capture(model, **_kwargs): captured["MODEL_NAME"] = os.environ.get("MODEL_NAME") fake_mx = _build_fake_modelexpress(publish_side_effect=_capture) with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(MagicMock()) + loader.publish_as_source(MagicMock(), source_identity=_identity()) assert captured["MODEL_NAME"] == "from-env-only" assert os.environ.get("MODEL_NAME") == "from-env-only" @@ -691,11 +1035,11 @@ def test_unknown_when_all_sources_missing(self): loader = MXCheckpointLoader(mx_server_url="http://mx:8001") captured = {} - def _capture(model): + def _capture(model, **_kwargs): captured["MODEL_NAME"] = os.environ.get("MODEL_NAME") fake_mx = _build_fake_modelexpress(publish_side_effect=_capture) with _install_fake_modelexpress(fake_mx): - loader.publish_as_source(MagicMock()) + loader.publish_as_source(MagicMock(), source_identity=_identity()) assert captured["MODEL_NAME"] == "unknown" diff --git a/tests/unittest/_torch/modules/mamba/test_mamba2_mixer.py b/tests/unittest/_torch/modules/mamba/test_mamba2_mixer.py new file mode 100644 index 000000000000..84d184708726 --- /dev/null +++ b/tests/unittest/_torch/modules/mamba/test_mamba2_mixer.py @@ -0,0 +1,41 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from types import SimpleNamespace + +import pytest +import torch +from torch import nn + +mamba2_mixer_mod = pytest.importorskip("tensorrt_llm._torch.modules.mamba.mamba2_mixer") +Mamba2Mixer = mamba2_mixer_mod.Mamba2Mixer + + +def test_mamba2_mixer_post_load_weights_caches_derived_state(): + mixer = Mamba2Mixer.__new__(Mamba2Mixer) + nn.Module.__init__(mixer) + mixer.norm = SimpleNamespace(is_nvfp4=False) + mixer.A = torch.tensor([1.0, 2.0]) + mixer.dt_bias = torch.tensor([3.0, 4.0]) + mixer.D = torch.tensor([5.0, 6.0]) + mixer.head_dim = 2 + mixer.d_state = 3 + + mixer.post_load_weights() + + assert mixer._A_expanded.shape == (2, 2, 3) + assert mixer._dt_bias_expanded.shape == (2, 2) + assert mixer._D_expanded.shape == (2, 2) + assert not hasattr(mixer, "_weights_transformed") diff --git a/tests/unittest/_torch/modules/moe/test_moe_backend.py b/tests/unittest/_torch/modules/moe/test_moe_backend.py index 8006a8c890c8..de9579ba1c17 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_backend.py +++ b/tests/unittest/_torch/modules/moe/test_moe_backend.py @@ -29,7 +29,9 @@ import itertools import logging import os +from types import SimpleNamespace from typing import List, Optional +from unittest.mock import MagicMock import pytest import torch @@ -66,7 +68,6 @@ logger = logging.getLogger(__name__) - _MEGAMOE_BACKEND_TYPES = { MoeBackendType.MEGAMOE_DEEPGEMM, MoeBackendType.MEGAMOE_CUTEDSL, @@ -185,6 +186,131 @@ def create_test_backend( ) +def test_moe_post_load_weights_uses_idempotent_transform_hook(): + class HookTestMoE(MoE): + def create_weights(self): + raise NotImplementedError + + def load_weights(self, weights, allow_partial_loading=False): + raise NotImplementedError + + def quantize_input(self, x, **kwargs): + return x, None + + def run_moe(self, **kwargs): + raise NotImplementedError + + moe = HookTestMoE.__new__(HookTestMoE) + 
torch.nn.Module.__init__(moe) + quant_method = SimpleNamespace( + transform_weights=MagicMock(), + cache_derived_state=MagicMock(), + ) + moe.quant_method = quant_method + + moe.post_load_weights() + moe.transform_weights() + + quant_method.transform_weights.assert_called_once_with(moe) + quant_method.cache_derived_state.assert_called_once_with(moe) + assert moe._weights_transformed is True + + moe.cache_derived_state() + assert quant_method.cache_derived_state.call_count == 2 + + moe._weights_transformed = False + moe.transform_weights() + assert quant_method.transform_weights.call_count == 2 + + +def test_configurable_moe_post_load_weights_uses_backend_staged_hooks(): + from tensorrt_llm._torch.modules.fused_moe.configurable_moe import ConfigurableMoE + + class HookTestConfigurableMoE(ConfigurableMoE): + def quantize_input(self, x, **kwargs): + return x, None + + def run_moe(self, **kwargs): + raise NotImplementedError + + configurable_moe = HookTestConfigurableMoE.__new__(HookTestConfigurableMoE) + torch.nn.Module.__init__(configurable_moe) + backend = torch.nn.Module() + backend.load_weights = MagicMock() + backend.transform_weights = MagicMock() + backend.cache_derived_state = MagicMock() + configurable_moe.backend = backend + + configurable_moe.post_load_weights() + configurable_moe.transform_weights() + + backend.transform_weights.assert_called_once_with() + backend.cache_derived_state.assert_called_once_with() + assert configurable_moe._weights_transformed is True + + configurable_moe.cache_derived_state() + assert backend.cache_derived_state.call_count == 2 + + configurable_moe.load_weights([{"fresh": "weights"}]) + backend.load_weights.assert_called_once_with([{"fresh": "weights"}], False) + assert configurable_moe._weights_transformed is False + + +def test_megamoe_load_weights_invalidates_cached_deepgemm_views(): + method = W4A8MXFP4MXFP8MegaMoEDeepGemmMethod() + hidden_size = 128 + intermediate_size = 128 + module = SimpleNamespace( + 
weight_loading_mode=MoEWeightLoadingMode.VANILLA, + initial_local_expert_ids=[0], + w3_w1_weight=torch.empty(1, intermediate_size * 2, hidden_size // 2, dtype=torch.uint8), + w3_w1_weight_scale=torch.empty( + 1, intermediate_size * 2, hidden_size // 32, dtype=torch.uint8 + ), + w2_weight=torch.empty(1, hidden_size, intermediate_size // 2, dtype=torch.uint8), + w2_weight_scale=torch.empty(1, hidden_size, intermediate_size // 32, dtype=torch.uint8), + _t_l1=(torch.empty(1), torch.empty(1)), + _t_l2=(torch.empty(1), torch.empty(1)), + _t_l1_weight=torch.empty(1), + _t_l1_scale=torch.empty(1), + _t_l1_scale_slot=torch.empty(1), + _t_l2_weight=torch.empty(1), + _t_l2_scale=torch.empty(1), + _t_l2_scale_slot=torch.empty(1), + ) + weights = { + "0.w1.weight": torch.full((intermediate_size, hidden_size // 2), 1, dtype=torch.uint8), + "0.w3.weight": torch.full((intermediate_size, hidden_size // 2), 2, dtype=torch.uint8), + "0.w2.weight": torch.full((hidden_size, intermediate_size // 2), 3, dtype=torch.uint8), + "0.w1.weight_scale": torch.full( + (intermediate_size, hidden_size // 32), 4, dtype=torch.uint8 + ), + "0.w3.weight_scale": torch.full( + (intermediate_size, hidden_size // 32), 5, dtype=torch.uint8 + ), + "0.w2.weight_scale": torch.full( + (hidden_size, intermediate_size // 32), 6, dtype=torch.uint8 + ), + } + + method.load_weights(module, [weights]) + + assert module.w3_w1_weight[0, 0, 0].item() == 1 + assert module.w3_w1_weight[0, intermediate_size, 0].item() == 2 + assert module._weights_loaded is True + for attr in ( + "_t_l1", + "_t_l2", + "_t_l1_weight", + "_t_l1_scale", + "_t_l1_scale_slot", + "_t_l2_weight", + "_t_l2_scale", + "_t_l2_scale_slot", + ): + assert getattr(module, attr) is None + + def test_megamoe_init_rejects_uneven_num_slots_with_value_error(): routing_method = RenormalizeMoeRoutingMethod(top_k=1) model_config = ModelConfig( diff --git a/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py 
b/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py index f429f1c5f4d7..a1a6e1b9b700 100644 --- a/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py +++ b/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py @@ -29,6 +29,7 @@ class _TinyModel(nn.Module): def __init__(self, events, *, include_draft=False): super().__init__() + self._weights_transformed = False self._events = events if include_draft: self.draft_model = nn.Module() @@ -44,7 +45,13 @@ def to(self, *args, **kwargs): def load_weights(self, weights, mapper): self._events.append("load_weights") - def post_load_weights(self): + def setup_aliases(self) -> None: + self._events.append("setup_aliases") + + def cache_derived_state(self) -> None: + self._events.append("cache_derived_state") + + def post_load_weights(self) -> None: self._events.append("post_load_weights") @@ -76,7 +83,9 @@ def _make_loader(monkeypatch, *, events, spec_config=None): loader._call_load_weights = MagicMock( side_effect=lambda fn, weights, mapper, **kwargs: fn(weights, mapper) ) - loader._load_and_validate_config = MagicMock(return_value=SimpleNamespace(name="config")) + loader._load_and_validate_config = MagicMock( + return_value=SimpleNamespace(name="config", mapping=SimpleNamespace()) + ) monkeypatch.setattr(model_loader_mod, "timing", lambda *_args, **_kwargs: nullcontext()) monkeypatch.setattr(model_loader_mod, "maybe_create_moe_load_balancer", _moe_context) @@ -123,6 +132,19 @@ def _install_gms_backend(monkeypatch, backend): monkeypatch.setattr(memory_mod, "GMSBackend", MagicMock(return_value=backend)) +class _PostTransformMxLoader: + checkpoint_format = "MX" + + def __init__(self) -> None: + self.load_weights = MagicMock(return_value={}) + self.is_weights_preloaded = MagicMock(return_value=True) + self.post_load_apply = MagicMock() + self.post_load_publish = MagicMock() + + def is_post_transform_weights_preloaded(self) -> bool: + return True + + def _spec_config_needing_draft_weights(): return 
SimpleNamespace( spec_dec_mode=SimpleNamespace(need_load_draft_weights=lambda: True), @@ -147,7 +169,7 @@ def _spec_config_needing_draft_weights(): ), pytest.param( False, - ["post_load_weights", "materialize"], + ["setup_aliases", "materialize", "cache_derived_state"], id="ro", ), ], @@ -161,8 +183,9 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events): (``_apply`` for meta materialization, ``to('cuda')``, weight load, ``post_load_weights``) inside the pool, then commits via ``finalize_write`` once the scope exits. - ro: the reader runs ``post_load_weights`` to wire module aliases - first, then GMS materializes weights via zero-copy mapping. + ro: the reader runs ``setup_aliases`` to wire module aliases, checks + identity compatibility, materializes weights via zero-copy mapping, + then refreshes derived state from real tensors. """ events = [] loader = _make_loader(monkeypatch, events=events) @@ -186,6 +209,8 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events): # ``model=model`` is passed for symmetry with the LoadFormat.AUTO # path (see model_loader.py); HF ignores it, MX uses it for direct # P2P writes when MX+GMS composition eventually lands. + # ``source_identity`` is included so format-specific loaders can + # publish the same compatibility fingerprint the RO path validates. checkpoint_loader.load_weights.assert_called_once_with( "/ckpt", mapping=loader.mapping, @@ -196,13 +221,55 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events): backend.move_untracked_params.assert_called_once_with(model) backend.finalize_write.assert_called_once_with(model) else: - # RO: post_load_weights() must run before the GMS materialize - # step so module aliases are wired up before zero-copy mapping. + # RO: setup_aliases() must run before the GMS materialize step so + # module aliases are wired up before zero-copy mapping. 
checkpoint_loader.load_weights.assert_not_called() loader._call_load_weights.assert_not_called() backend.materialize_module.assert_called_once_with(model) +def test_gms_ro_materializes_between_alias_setup_and_cache_state(monkeypatch): + events = [] + loader = _make_loader(monkeypatch, events=events) + backend = _build_gms_backend(is_rw=False, events=events) + _install_gms_backend(monkeypatch, backend) + + checkpoint_loader = MagicMock(name="checkpoint_loader") + checkpoint_loader.checkpoint_format = "HF" + + def record(event): + def _append(*_args, **_kwargs): + events.append(event) + + return _append + + checkpoint_loader.post_load_apply.side_effect = record("post_load_apply") + checkpoint_loader.post_load_publish.side_effect = record("post_load_publish") + + # The STRICT pre-materialize identity gate runs between alias setup and + # materialization; record it to pin the ordering without exercising the + # comparison logic, which is covered in test_source_identity.py. + monkeypatch.setattr( + model_loader_mod, + "check_weight_sharing_compatibility", + lambda *_args, **_kwargs: events.append("check_source_identity"), + ) + + loader.load("/ckpt", checkpoint_loader) + + assert events == [ + "post_load_apply", + "setup_aliases", + "check_source_identity", + "materialize", + "cache_derived_state", + "post_load_publish", + ] + assert "post_load_weights" not in events + checkpoint_loader.load_weights.assert_not_called() + backend.materialize_module.assert_called_once() + + def test_gms_rw_post_load_runs_inside_pool_before_finalize(monkeypatch): """Every step that may allocate or rebind tensors must run inside the GMS pool. 
@@ -240,9 +307,9 @@ def test_gms_rw_post_load_runs_inside_pool_before_finalize(monkeypatch): "to", "load_weights", "post_load_apply", - "post_load_publish", "post_load_weights", "move_untracked_params", + "post_load_publish", "pool_exit", "finalize_write", ] @@ -294,6 +361,58 @@ def test_gms_rw_loader_preload_skips_mapping_pipeline(monkeypatch): backend.finalize_write.assert_called_once_with(model) +def test_gms_rw_mx_post_transform_preload_uses_staged_path(monkeypatch): + """GMS writers that receive post-transform MX bytes must not transform again.""" + events = [] + loader = _make_loader(monkeypatch, events=events) + monkeypatch.setattr( + ModelLoader, + "_MX_STAGED_RECEIVER_ALLOWLIST", + frozenset({(_TinyModel, ModelLoader._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION)}), + ) + backend = _build_gms_backend(is_rw=True, events=events) + backend.move_untracked_params.side_effect = lambda _model: events.append( + "move_untracked_params" + ) + backend.finalize_write.side_effect = lambda _model: events.append("finalize_write") + _install_gms_backend(monkeypatch, backend) + + checkpoint_loader = _PostTransformMxLoader() + checkpoint_loader.post_load_apply.side_effect = lambda *_a, **_kw: events.append( + "post_load_apply" + ) + checkpoint_loader.post_load_publish.side_effect = lambda *_a, **_kw: events.append( + "post_load_publish" + ) + + model, _ = loader.load("/ckpt", checkpoint_loader) + + assert events == [ + "pool_enter", + "_apply", + "to", + "post_load_apply", + "setup_aliases", + "cache_derived_state", + "move_untracked_params", + "post_load_publish", + "pool_exit", + "finalize_write", + ] + assert "post_load_weights" not in events + assert model._weights_transformed is True + _args, kwargs = checkpoint_loader.load_weights.call_args + assert kwargs["allow_post_transform_weights"] is True + loader._call_load_weights.assert_not_called() + checkpoint_loader.post_load_publish.assert_called_once_with( + model, + checkpoint_dir="/ckpt", + weights_preloaded=True, + 
source_identity=loader._source_identity, + ) + backend.finalize_write.assert_called_once_with(model) + + def test_gms_rw_no_load_and_no_preload_raises(monkeypatch): """RW + empty ``weights`` + ``is_weights_preloaded()=False`` is a bug. diff --git a/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py b/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py index 0ffe6fb32a2a..15f55393948f 100644 --- a/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py +++ b/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py @@ -6,9 +6,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock +import pytest import torch from torch import nn +from transformers import LlamaConfig +from tensorrt_llm._torch.model_config import ModelConfig +from tensorrt_llm._torch.models import modeling_llama as modeling_llama_mod +from tensorrt_llm._torch.modules import attention as attention_mod +from tensorrt_llm._torch.modules.attention import MLA +from tensorrt_llm._torch.modules.linear import Linear from tensorrt_llm._torch.pyexecutor import model_loader as model_loader_mod from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader from tensorrt_llm.llmapi.llm_args import LoadFormat @@ -25,6 +32,10 @@ class _LinearStub(nn.Module): + def __init__(self): + super().__init__() + self._weights_transformed = False + def post_load_weights(self): pass @@ -38,6 +49,7 @@ def __init__(self): class _TinyModel(nn.Module): def __init__(self, events): super().__init__() + self._weights_transformed = False self.linear = _LinearStub() self.draft_model = _DraftModel() self.draft_config = SimpleNamespace( @@ -58,6 +70,12 @@ def load_weights(self, weights, mapper): def load_draft_weights(self, weights, mapper): self._events.append("load_draft_weights") + def setup_aliases(self): + self._events.append("setup_aliases") + + def cache_derived_state(self): + self._events.append("cache_derived_state") + def post_load_weights(self): self._events.append("post_load_weights") @@ 
-67,6 +85,54 @@ def _moe_context(config, mapping): yield None +def _tiny_llama_model(monkeypatch): + monkeypatch.setattr(modeling_llama_mod, "get_sm_version", lambda: 90) + llama_config = LlamaConfig( + architectures=["LlamaForCausalLM"], + attention_bias=False, + hidden_act="silu", + hidden_size=16, + intermediate_size=32, + max_position_embeddings=16, + mlp_bias=False, + num_attention_heads=2, + num_hidden_layers=2, + num_key_value_heads=2, + rms_norm_eps=1e-5, + tie_word_embeddings=False, + torch_dtype=torch.float32, + vocab_size=32, + ) + return model_loader_mod.LlamaForCausalLM( + ModelConfig( + pretrained_config=llama_config, + max_num_tokens=16, + max_seq_len=16, + ) + ) + + +def _llama_alias_state(model): + layers = model.model.layers + return { + "skip_norm": model.model.skip_norm, + "layer0_next_norm": layers[0].next_layer_layernorm is layers[1].input_layernorm, + "layer0_next_attn": layers[0].next_attn is layers[1].self_attn, + "layer1_skip_input_norm": layers[1].skip_input_layernorm, + "layer1_next_norm": layers[1].next_layer_layernorm is model.model.norm, + "layer1_next_attn": layers[1].next_attn is None, + } + + +def _transform_guard_state(model): + return { + name: module._weights_transformed + for name, module in model.named_modules() + if hasattr(module, "_weights_transformed") + and not getattr(module, "_weights_removed", False) + } + + def _make_loader(monkeypatch, *, events, spec_config=None): llm_args = SimpleNamespace(load_format=LoadFormat.AUTO) loader = ModelLoader( @@ -108,6 +174,13 @@ def _make_loader(monkeypatch, *, events, spec_config=None): return loader +def _spec_config_needing_draft_weights() -> SimpleNamespace: + return SimpleNamespace( + spec_dec_mode=SimpleNamespace(need_load_draft_weights=lambda: True), + speculative_model="/draft-ckpt", + ) + + def test_mx_success_initializes_mapper_skips_weight_mapping_and_reload_works(monkeypatch): events = [] loader = _make_loader(monkeypatch, events=events) @@ -123,20 +196,45 @@ def 
test_mx_success_initializes_mapper_skips_weight_mapping_and_reload_works(mon assert kwargs["mapping"] is loader.mapping assert kwargs["model"] is model assert kwargs["source_identity"] is loader._source_identity + assert kwargs["allow_post_transform_weights"] is False assert loader._call_load_weights.call_count == 0 checkpoint_loader.get_initialized_weight_mapper.assert_called_once() assert loader.weight_mapper is checkpoint_loader.get_initialized_weight_mapper.return_value checkpoint_loader.post_load_publish.assert_called_once_with( - model, checkpoint_dir="/ckpt", weights_preloaded=True + model, + checkpoint_dir="/ckpt", + weights_preloaded=True, + source_identity=loader._source_identity, ) # reload() uses self.weight_mapper unconditionally; MX success must # initialize it even though the initial load skipped _call_load_weights. + model._weights_transformed = True + model.linear._weights_transformed = True loader.reload(model, {"reloaded": MagicMock()}) assert loader._call_load_weights.call_count == 1 + assert model._weights_transformed is False + assert model.linear._weights_transformed is False assert events == ["post_load_weights", "load_weights"] +def test_reload_partial_loading_preserves_weights_transformed_flags(monkeypatch): + events = [] + loader = _make_loader(monkeypatch, events=events) + loader.weight_mapper = MagicMock(name="weight_mapper") + model = _TinyModel(events) + model._weights_transformed = True + model.linear._weights_transformed = True + + loader.reload(model, {"reloaded": MagicMock()}, allow_partial_loading=True) + + assert loader._call_load_weights.call_count == 1 + assert loader._call_load_weights.call_args.kwargs["allow_partial_loading"] is True + assert model._weights_transformed is True + assert model.linear._weights_transformed is True + assert events == ["load_weights"] + + def test_mx_partial_fallback_merges_returned_weights(monkeypatch): events = [] loader = _make_loader(monkeypatch, events=events) @@ -154,8 +252,157 @@ def 
test_mx_partial_fallback_merges_returned_weights(monkeypatch): assert weights is fallback_weights assert mapper is loader.weight_mapper checkpoint_loader.post_load_publish.assert_called_once_with( - model, checkpoint_dir="/ckpt", weights_preloaded=True + model, + checkpoint_dir="/ckpt", + weights_preloaded=True, + source_identity=loader._source_identity, + ) + + +class _PostTransformMxLoader: + checkpoint_format = "MX" + + def __init__(self, *, post_transform: bool) -> None: + self._post_transform = post_transform + self._weights_preloaded = True + self.load_weights = MagicMock(side_effect=self._load_weights) + self.is_weights_preloaded = MagicMock(side_effect=lambda: self._weights_preloaded) + self.get_initialized_weight_mapper = MagicMock(return_value=MagicMock()) + self.post_load_apply = MagicMock() + self.post_load_publish = MagicMock() + + def _load_weights(self, *_args, **kwargs): + if self._post_transform and kwargs.get("allow_post_transform_weights") is False: + self._post_transform = False + self._weights_preloaded = False + return {"disk.weight": MagicMock()} + return {} + + def is_post_transform_weights_preloaded(self) -> bool: + return self._post_transform + + +def test_mx_post_transform_receiver_uses_staged_path_when_allowlisted(monkeypatch): + events = [] + loader = _make_loader(monkeypatch, events=events) + monkeypatch.setattr( + ModelLoader, + "_MX_STAGED_RECEIVER_ALLOWLIST", + frozenset({(_TinyModel, ModelLoader._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION)}), + ) + checkpoint_loader = _PostTransformMxLoader(post_transform=True) + + model, _ = loader.load("/ckpt", checkpoint_loader) + + loader._call_load_weights.assert_not_called() + _args, kwargs = checkpoint_loader.load_weights.call_args + assert kwargs["allow_post_transform_weights"] is True + checkpoint_loader.post_load_publish.assert_called_once_with( + model, + checkpoint_dir="/ckpt", + weights_preloaded=True, + source_identity=loader._source_identity, + ) + # Post-transform receivers skip 
transform_weights(), but the accepted + # tensors are already in final layout. Keep the transform guard in sync so + # future reload/refactor paths do not accidentally treat them as raw bytes. + assert model._weights_transformed is True + assert model.linear._weights_transformed is True + assert model.draft_model.linear._weights_transformed is True + assert events == ["setup_aliases", "cache_derived_state"] + + +def test_default_mx_staged_receiver_allowlist_matches_real_tiny_llama(monkeypatch): + full_model = _tiny_llama_model(monkeypatch) + staged_model = _tiny_llama_model(monkeypatch) + checkpoint_loader = _PostTransformMxLoader(post_transform=True) + loader = _make_loader(monkeypatch, events=[]) + + ModelLoader._walk_full_post_load(full_model) + assert ( + loader._should_run_mx_staged_receiver_path( + checkpoint_loader, + staged_model, + weights_preloaded=True, + ) + is True + ) + + transform_weights_calls = [] + for module in staged_model.modules(): + transform_weights = getattr(module, "transform_weights", None) + if transform_weights is None: + continue + + def _record_transform_weights(*_args, module=module, **_kwargs): + transform_weights_calls.append(type(module).__name__) + + module.transform_weights = _record_transform_weights + + ModelLoader._setup_aliases(staged_model) + ModelLoader._mark_weights_transformed(staged_model) + ModelLoader._walk_cache_state(staged_model) + + assert transform_weights_calls == [] + assert _llama_alias_state(staged_model) == _llama_alias_state(full_model) + assert _transform_guard_state(staged_model) == _transform_guard_state(full_model) + + +def test_mx_post_transform_receiver_disabled_when_draft_weights_load_separately( + monkeypatch, +) -> None: + events = [] + loader = _make_loader( + monkeypatch, + events=events, + spec_config=_spec_config_needing_draft_weights(), + ) + monkeypatch.setattr( + ModelLoader, + "_MX_STAGED_RECEIVER_ALLOWLIST", + frozenset({(_TinyModel, 
ModelLoader._MX_STAGED_RECEIVER_TRANSFORM_PROTOCOL_VERSION)}), + ) + draft_mapper = MagicMock() + monkeypatch.setattr( + model_loader_mod.AutoCheckpointMapper, + "get", + MagicMock(return_value=draft_mapper), + ) + checkpoint_loader = MagicMock(name="checkpoint_loader") + checkpoint_loader.checkpoint_format = "MX" + checkpoint_loader.load_weights.side_effect = [ + {"primary.weight": MagicMock()}, + {"draft.weight": MagicMock()}, + ] + checkpoint_loader.is_weights_preloaded.return_value = False + checkpoint_loader.is_post_transform_weights_preloaded.return_value = False + checkpoint_loader.get_initialized_weight_mapper.return_value = MagicMock() + + model, _ = loader.load("/ckpt", checkpoint_loader) + + primary_call = checkpoint_loader.load_weights.call_args_list[0] + assert primary_call.kwargs["allow_post_transform_weights"] is False + assert loader._call_load_weights.call_count == 2 + checkpoint_loader.post_load_publish.assert_called_once_with( + model, + checkpoint_dir="/ckpt", + weights_preloaded=False, + source_identity=loader._source_identity, ) + assert events == ["load_weights", "load_draft_weights", "post_load_weights"] + + +def test_mx_post_transform_receiver_rejects_non_allowlisted_model(monkeypatch): + events = [] + loader = _make_loader(monkeypatch, events=events) + checkpoint_loader = _PostTransformMxLoader(post_transform=True) + + with pytest.raises(RuntimeError, match="not allow-listed"): + loader.load("/ckpt", checkpoint_loader) + + _args, kwargs = checkpoint_loader.load_weights.call_args + assert kwargs["allow_post_transform_weights"] is False + assert "post_load_weights" not in events def test_mx_fallback_runs_standard_weight_mapping(monkeypatch): @@ -173,7 +420,10 @@ def test_mx_fallback_runs_standard_weight_mapping(monkeypatch): assert events[0] == "load_weights" assert "post_load_weights" in events checkpoint_loader.post_load_publish.assert_called_once_with( - model, checkpoint_dir="/ckpt", weights_preloaded=False + model, + 
checkpoint_dir="/ckpt", + weights_preloaded=False, + source_identity=loader._source_identity, ) @@ -194,17 +444,17 @@ def __init__( if transformed is not None: self._weights_transformed = transformed - def setup_aliases(self): + def setup_aliases(self) -> None: self.events.append((self.name, "setup_aliases")) - def transform_weights(self): + def transform_weights(self) -> None: self.events.append((self.name, "transform_weights")) self._weights_transformed = True - def cache_derived_state(self): + def cache_derived_state(self) -> None: self.events.append((self.name, "cache_derived_state")) - def post_load_weights(self): + def post_load_weights(self) -> None: self.events.append((self.name, "post_load_weights")) @@ -216,13 +466,17 @@ def __init__(self, events): self.removed_child = _HookRecorder("removed_child", events, removed=True) -def test_staged_hook_setup_aliases_is_top_level_only(): +def test_staged_hook_setup_aliases_walks_skip_removed_modules(): events = [] model = _HookModel(events) ModelLoader._setup_aliases(model) - assert events == [("model", "setup_aliases")] + assert events == [ + ("model", "setup_aliases"), + ("child", "setup_aliases"), + ("transformed_child", "setup_aliases"), + ] def test_staged_hook_walks_skip_removed_and_transformed_modules(): @@ -257,3 +511,78 @@ def test_reset_weights_transformed_only_resets_existing_flags(): assert model.child._weights_transformed is False assert model.transformed_child._weights_transformed is False assert not hasattr(model.removed_child, "_weights_transformed") + + +def test_mark_weights_transformed_only_sets_existing_flags(): + events = [] + model = _HookModel(events) + model._weights_transformed = False + model.child._weights_transformed = False + + ModelLoader._mark_weights_transformed(model) + + assert model._weights_transformed is True + assert model.child._weights_transformed is True + assert model.transformed_child._weights_transformed is True + assert not hasattr(model.removed_child, 
"_weights_transformed") + + +def test_linear_transform_weights_is_idempotent(): + linear = Linear( + 1, + 1, + bias=False, + reduce_output=False, + skip_create_weights_in_init=True, + ) + linear.quant_method = MagicMock() + + linear.transform_weights() + linear.post_load_weights() + + linear.quant_method.transform_weights.assert_called_once_with(linear) + assert linear._weights_transformed is True + + linear._weights_transformed = False + linear.post_load_weights() + assert linear.quant_method.transform_weights.call_count == 2 + + linear._weights_transformed = False + linear.cache_derived_state() + assert linear._weights_transformed is True + + +def test_mla_transform_weights_is_idempotent(monkeypatch): + monkeypatch.setattr(attention_mod, "get_sm_version", lambda: 120) + quant_mode = SimpleNamespace(has_fp8_block_scales=lambda: True) + mla = MLA.__new__(MLA) + mla._weights_transformed = False + mla.kv_b_proj = SimpleNamespace(quant_config=SimpleNamespace(quant_mode=quant_mode)) + mla.k_b_proj_trans = "k_weight" + mla.k_b_proj_trans_scale = "k_scale" + mla.v_b_proj = "v_weight" + mla.v_b_proj_scale = "v_scale" + calls = [] + + def fake_resmooth(weight, scale, recipe): + calls.append((weight, scale, recipe)) + return f"{weight}_transformed", f"{scale}_transformed" + + mla.resmooth_parameters = fake_resmooth + + MLA.transform_weights(mla) + MLA.post_load_weights(mla) + + assert calls == [ + ("k_weight", "k_scale", (1, 128, 128)), + ("v_weight", "v_scale", (1, 128, 128)), + ] + assert mla.k_b_proj_trans == "k_weight_transformed" + assert mla.k_b_proj_trans_scale == "k_scale_transformed" + assert mla.v_b_proj == "v_weight_transformed" + assert mla.v_b_proj_scale == "v_scale_transformed" + assert mla._weights_transformed is True + + mla._weights_transformed = False + MLA.cache_derived_state(mla) + assert mla._weights_transformed is True diff --git a/tests/unittest/_torch/weight_sharing/test_mx_source_identity_gate.py 
b/tests/unittest/_torch/weight_sharing/test_mx_source_identity_gate.py index ded26762baf7..4eb1b566df56 100644 --- a/tests/unittest/_torch/weight_sharing/test_mx_source_identity_gate.py +++ b/tests/unittest/_torch/weight_sharing/test_mx_source_identity_gate.py @@ -22,10 +22,15 @@ without any model, GPU, or RDMA. """ +from types import SimpleNamespace + from _source_identity_fakes import FakeMapping from _source_identity_fakes import make_identity as _identity -from tensorrt_llm._torch.models.checkpoints.mx.checkpoint_loader import MXCheckpointLoader +from tensorrt_llm._torch.models.checkpoints.mx.checkpoint_loader import ( + MXCheckpointLoader, + _build_mx_source_metadata, +) def _new_loader(local_identity, source_identity, fetched=True): @@ -70,11 +75,36 @@ def test_gate_falls_back_when_source_identity_unavailable(): assert loader._source_identity_compatible("ckpt", _STUB_CLIENT, _STUB_BUILD) is False -def test_fetch_source_identity_returns_none_until_upstream_wired(): - # The seam currently returns None by contract; this pins that behavior so - # the day it is wired to real MX metadata, the change is deliberate. 
+def test_fetch_source_identity_returns_none_when_metadata_unavailable(): + loader = MXCheckpointLoader.__new__(MXCheckpointLoader) + loader._mx_server_url = "http://mx:8001" + loader._model_name = None + + class _Client: + def __init__(self, *, server_url): + self.server_url = server_url + + def list_sources(self, *, identity): + return SimpleNamespace(instances=[]) + + assert loader._fetch_source_identity("ckpt", _Client, lambda **_kw: object()) is None + + +def test_fetch_source_identity_from_source_metadata(): + source = _identity() loader = MXCheckpointLoader.__new__(MXCheckpointLoader) - assert loader._fetch_source_identity("ckpt", _STUB_CLIENT, _STUB_BUILD) is None + loader._mx_server_url = "http://mx:8001" + loader._model_name = None + + class _Client: + def __init__(self, *, server_url): + self.server_url = server_url + + def list_sources(self, *, identity): + instance = SimpleNamespace(metadata=_build_mx_source_metadata(source)) + return SimpleNamespace(instances=[instance]) + + assert loader._fetch_source_identity("ckpt", _Client, lambda **_kw: object()) == source def test_load_weights_pops_source_identity_kwarg():