From ffd91fb455b47ab275c473c6aebc9c4ae829ce67 Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Thu, 11 Jun 2026 18:07:52 -0700 Subject: [PATCH 1/2] [TRTLLM-13247][feat] Wave 2: stage Linear and Attention transforms Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> --- .../_torch/memory/gpu_memory_backend.py | 7 ++- tensorrt_llm/_torch/modules/attention.py | 10 +++- tensorrt_llm/_torch/modules/linear.py | 30 +++++++--- .../pyexecutor/test_model_loader_gms.py | 2 + .../_torch/pyexecutor/test_model_loader_mx.py | 56 +++++++++++++++++++ 5 files changed, 92 insertions(+), 13 deletions(-) diff --git a/tensorrt_llm/_torch/memory/gpu_memory_backend.py b/tensorrt_llm/_torch/memory/gpu_memory_backend.py index 12edcd071998..153f0f9c5c79 100644 --- a/tensorrt_llm/_torch/memory/gpu_memory_backend.py +++ b/tensorrt_llm/_torch/memory/gpu_memory_backend.py @@ -42,9 +42,10 @@ whose `post_load_weights()` is pure alias wiring; models that additionally rely on plain Python attributes set inside `post_load_weights()` (rather than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate - those side effects to `cache_derived_state()` or another path that runs on - RO readers. The GMS RO reader runs `setup_aliases()` before - `materialize_module()` and `cache_derived_state()` afterward; it does not + those side effects to `cache_derived_state()` or another hook that runs on + RO readers. One-shot tensor layout changes belong in `transform_weights()` + on the writer; the GMS RO reader runs `setup_aliases()` before + `materialize_module()`, then `cache_derived_state()` afterward. It does not run `transform_weights()`. """ diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index cc0694728829..d5b75b661ca3 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -1237,6 +1237,7 @@ def __init__( self.layer_idx = layer_idx self.layer_idx_str = str(layer_idx) self.dtype = dtype + self._weights_transformed = False self.hidden_size = hidden_size self.num_heads = num_attention_heads @@ -1624,6 +1625,7 @@ def create_weights(self): else: self.k_b_proj_trans_scale = None self.v_b_proj_scale = None + self._weights_transformed = False def apply_rope( self, @@ -3003,7 +3005,9 @@ def resmooth_parameters(self, return weight_param, scale_param - def post_load_weights(self): + def transform_weights(self) -> None: + if self._weights_transformed: + return has_fp8_block_scales = ( self.kv_b_proj.quant_config and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales()) @@ -3016,3 +3020,7 @@ def post_load_weights(self): self.v_b_proj, self.v_b_proj_scale = self.resmooth_parameters( self.v_b_proj, self.v_b_proj_scale, recipe=(1, 128, 128)) + self._weights_transformed = True + + def post_load_weights(self) -> None: + self.transform_weights() diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index aae3f3a65ab2..67d521e1cc07 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -380,9 +380,12 @@ def load_weights(self, if not allow_partial_loading: self.process_weights_after_loading(module) - def post_load_weights(self, module: Linear): + def transform_weights(self, module: Linear) -> None: pass + def post_load_weights(self, module: Linear) -> None: + self.transform_weights(module) + def load_weight_scales(self, weights: List[Dict], *args, **kwargs): """ Load quantized weight scales from the checkpoint. @@ -1241,8 +1244,8 @@ def load_weights_fused_gate_up_linear( copy_weight_shard(module.weight_scale, scale, shard_offset, shard_size) - def post_load_weights(self, module: Linear): - super().post_load_weights(module) + def transform_weights(self, module: Linear) -> None: + super().transform_weights(module) if (is_sm_100f() and not (module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm)) or \ get_sm_version() == 120: @@ -1821,9 +1824,9 @@ def process_weights_after_loading_fused_gate_up_linear( torch.ops.trtllm.block_scale_interleave(ws_swapped), requires_grad=False) - def post_load_weights(self, module: Linear): + def transform_weights(self, module: Linear) -> None: """Pad weight and weight_scale tensors to meet torch trtllm NVFP4 GEMM alignment requirements.""" - super().post_load_weights(module) + super().transform_weights(module) row_alignment, col_alignment = 32, 16 row_pad_size = (row_alignment - module.weight.size(0)) % row_alignment col_pad_size = (col_alignment - module.weight.size(1)) % col_alignment @@ -1873,10 +1876,10 @@ class W4A16NVFP4LinearMethod(NVFP4LinearMethod): its fused path is SM>=100-gated upstream. """ - def post_load_weights(self, module: Linear): + def transform_weights(self, module: Linear) -> None: # Skip parent's 32x16 weight padding (apply() accepts [N, K/2] as-is) # and un-swizzle per-block scale once at load. - LinearMethodBase.post_load_weights(self, module) + LinearMethodBase.transform_weights(self, module) pad_rows = fp4_utils.pad_up(module.out_features, 128) pad_cols = fp4_utils.pad_up( module.in_features // module.scaling_vector_size, 4) @@ -2914,6 +2917,7 @@ def __init__( dtype=self.dtype) if reduce_output else None self._weights_created = False + self._weights_transformed = False self.reduce_output = reduce_output self.use_custom_cublas_mm = use_custom_cublas_mm self.use_cute_dsl_bf16_gemm = use_cute_dsl_bf16_gemm @@ -2966,6 +2970,7 @@ def create_weights(self): self.dtype) self._weights_created = True + self._weights_transformed = False @property def has_any_quant(self): @@ -3127,6 +3132,7 @@ def load_weights(self, assert allow_partial_loading is False, ( f"{type(self.quant_method).__name__} does not support " "allow_partial_loading") + self._weights_transformed = False self.quant_method.load_weights( self, weights, @@ -3136,8 +3142,14 @@ def load_weights(self, def process_weights_after_loading(self): self.quant_method.process_weights_after_loading(self) - def post_load_weights(self): - self.quant_method.post_load_weights(self) + def transform_weights(self) -> None: + if self._weights_transformed: + return + self.quant_method.transform_weights(self) + self._weights_transformed = True + + def post_load_weights(self) -> None: + self.transform_weights() def pre_reload_weights(self): assert hasattr( diff --git a/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py b/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py index 3fcc353a9fa2..d84296a1ba95 100644 --- a/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py +++ b/tests/unittest/_torch/pyexecutor/test_model_loader_gms.py @@ -195,6 +195,8 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events): # ``model=model`` is passed for symmetry with the LoadFormat.AUTO # path (see model_loader.py); HF ignores it, MX uses it for direct # P2P writes when MX+GMS composition eventually lands. + # ``source_identity`` is included so format-specific loaders can + # publish the same compatibility fingerprint the RO path validates. checkpoint_loader.load_weights.assert_called_once_with( "/ckpt", mapping=loader.mapping, diff --git a/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py b/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py index c110cb1a6204..a3c7027d143d 100644 --- a/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py +++ b/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py @@ -9,6 +9,9 @@ import torch from torch import nn +from tensorrt_llm._torch.modules import attention as attention_mod +from tensorrt_llm._torch.modules.attention import MLA +from tensorrt_llm._torch.modules.linear import Linear from tensorrt_llm._torch.pyexecutor import model_loader as model_loader_mod from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader from tensorrt_llm.llmapi.llm_args import LoadFormat @@ -282,3 +285,56 @@ def test_reset_weights_transformed_only_resets_existing_flags(): assert model.child._weights_transformed is False assert model.transformed_child._weights_transformed is False assert not hasattr(model.removed_child, "_weights_transformed") + + +def test_linear_transform_weights_is_idempotent(): + linear = Linear( + 1, + 1, + bias=False, + reduce_output=False, + skip_create_weights_in_init=True, + ) + linear.quant_method = MagicMock() + + linear.transform_weights() + linear.post_load_weights() + + linear.quant_method.transform_weights.assert_called_once_with(linear) + assert linear._weights_transformed is True + + linear._weights_transformed = False + linear.post_load_weights() + assert linear.quant_method.transform_weights.call_count == 2 + + +def test_mla_transform_weights_is_idempotent(monkeypatch): + monkeypatch.setattr(attention_mod, "get_sm_version", lambda: 120) + quant_mode = SimpleNamespace(has_fp8_block_scales=lambda: True) + mla = MLA.__new__(MLA) + mla._weights_transformed = False + mla.kv_b_proj = SimpleNamespace(quant_config=SimpleNamespace(quant_mode=quant_mode)) + mla.k_b_proj_trans = "k_weight" + mla.k_b_proj_trans_scale = "k_scale" + mla.v_b_proj = "v_weight" + mla.v_b_proj_scale = "v_scale" + calls = [] + + def fake_resmooth(weight, scale, recipe): + calls.append((weight, scale, recipe)) + return f"{weight}_transformed", f"{scale}_transformed" + + mla.resmooth_parameters = fake_resmooth + + MLA.transform_weights(mla) + MLA.post_load_weights(mla) + + assert calls == [ + ("k_weight", "k_scale", (1, 128, 128)), + ("v_weight", "v_scale", (1, 128, 128)), + ] + assert mla.k_b_proj_trans == "k_weight_transformed" + assert mla.k_b_proj_trans_scale == "k_scale_transformed" + assert mla.v_b_proj == "v_weight_transformed" + assert mla.v_b_proj_scale == "v_scale_transformed" + assert mla._weights_transformed is True From 9450f16d7ab83678b9521eae3c43792b3cd76450 Mon Sep 17 00:00:00 2001 From: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> Date: Tue, 23 Jun 2026 10:22:30 -0700 Subject: [PATCH 2/2] [TRTLLM-13247][fix] Address CodeRabbit review comments Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com> --- tensorrt_llm/_torch/modules/attention.py | 3 +++ tensorrt_llm/_torch/modules/linear.py | 5 ++++- tests/unittest/_torch/pyexecutor/test_model_loader_mx.py | 8 ++++++++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index d5b75b661ca3..5e743e95120e 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -3022,5 +3022,8 @@ def transform_weights(self) -> None: self.v_b_proj, self.v_b_proj_scale, recipe=(1, 128, 128)) self._weights_transformed = True + def cache_derived_state(self) -> None: + self._weights_transformed = True + def post_load_weights(self) -> None: self.transform_weights() diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 67d521e1cc07..9c52ed377d8a 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -381,7 +381,7 @@ def load_weights(self, self.process_weights_after_loading(module) def transform_weights(self, module: Linear) -> None: - pass + ... def post_load_weights(self, module: Linear) -> None: self.transform_weights(module) @@ -3148,6 +3148,9 @@ def transform_weights(self) -> None: self.quant_method.transform_weights(self) self._weights_transformed = True + def cache_derived_state(self) -> None: + self._weights_transformed = True + def post_load_weights(self) -> None: self.transform_weights() diff --git a/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py b/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py index a3c7027d143d..fcd96364b56c 100644 --- a/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py +++ b/tests/unittest/_torch/pyexecutor/test_model_loader_mx.py @@ -307,6 +307,10 @@ def test_linear_transform_weights_is_idempotent(): linear.post_load_weights() assert linear.quant_method.transform_weights.call_count == 2 + linear._weights_transformed = False + linear.cache_derived_state() + assert linear._weights_transformed is True + def test_mla_transform_weights_is_idempotent(monkeypatch): monkeypatch.setattr(attention_mod, "get_sm_version", lambda: 120) @@ -338,3 +342,7 @@ def fake_resmooth(weight, scale, recipe): assert mla.v_b_proj == "v_weight_transformed" assert mla.v_b_proj_scale == "v_scale_transformed" assert mla._weights_transformed is True + + mla._weights_transformed = False + MLA.cache_derived_state(mla) + assert mla._weights_transformed is True