Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tensorrt_llm/_torch/memory/gpu_memory_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@
whose `post_load_weights()` is pure alias wiring; models that additionally
rely on plain Python attributes set inside `post_load_weights()` (rather
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
those side effects to `cache_derived_state()` or another path that runs on
RO readers. The GMS RO reader runs `setup_aliases()` before
`materialize_module()` and `cache_derived_state()` afterward; it does not
those side effects to `cache_derived_state()` or another hook that runs on
RO readers. One-shot tensor layout changes belong in `transform_weights()`
on the writer; the GMS RO reader runs `setup_aliases()` before
`materialize_module()`, then `cache_derived_state()` afterward. It does not
run `transform_weights()`.
"""

Expand Down
13 changes: 12 additions & 1 deletion tensorrt_llm/_torch/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,7 @@ def __init__(
self.layer_idx = layer_idx
self.layer_idx_str = str(layer_idx)
self.dtype = dtype
self._weights_transformed = False

self.hidden_size = hidden_size
self.num_heads = num_attention_heads
Expand Down Expand Up @@ -1624,6 +1625,7 @@ def create_weights(self):
else:
self.k_b_proj_trans_scale = None
self.v_b_proj_scale = None
self._weights_transformed = False

def apply_rope(
self,
Expand Down Expand Up @@ -3003,7 +3005,9 @@ def resmooth_parameters(self,

return weight_param, scale_param

def post_load_weights(self):
def transform_weights(self) -> None:
if self._weights_transformed:
return
has_fp8_block_scales = (
self.kv_b_proj.quant_config
and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales())
Expand All @@ -3016,3 +3020,10 @@ def post_load_weights(self):

self.v_b_proj, self.v_b_proj_scale = self.resmooth_parameters(
self.v_b_proj, self.v_b_proj_scale, recipe=(1, 128, 128))
self._weights_transformed = True

def cache_derived_state(self) -> None:
    """Record that this module's weights are already in transformed form.

    ``transform_weights()`` returns immediately once this flag is set, so
    calling this hook keeps a later transform request from resmoothing the
    projections again.
    """
    self._weights_transformed = True

def post_load_weights(self) -> None:
    """Run the post-load step for this module.

    The post-load work is exactly ``transform_weights()``; this method only
    forwards to it and returns nothing.
    """
    self.transform_weights()
35 changes: 25 additions & 10 deletions tensorrt_llm/_torch/modules/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,11 @@ def load_weights(self,
if not allow_partial_loading:
self.process_weights_after_loading(module)

def post_load_weights(self, module: Linear):
pass
def transform_weights(self, module: Linear) -> None:
    """Default weight-transform hook: intentionally does nothing.

    Linear methods that need to rewrite ``module``'s tensors after loading
    override this hook; ``post_load_weights()`` invokes it with the module.
    """

def post_load_weights(self, module: Linear) -> None:
    """Post-load entry point; hands ``module`` to ``transform_weights()``.

    Nothing else happens here, and the hook's result is not returned.
    """
    self.transform_weights(module)

def load_weight_scales(self, weights: List[Dict], *args, **kwargs):
"""
Expand Down Expand Up @@ -1241,8 +1244,8 @@ def load_weights_fused_gate_up_linear(
copy_weight_shard(module.weight_scale, scale, shard_offset,
shard_size)

def post_load_weights(self, module: Linear):
super().post_load_weights(module)
def transform_weights(self, module: Linear) -> None:
super().transform_weights(module)
if (is_sm_100f() and not (module.use_cute_dsl_blockscaling_mm
or module.disable_deep_gemm)) or \
get_sm_version() == 120:
Expand Down Expand Up @@ -1821,9 +1824,9 @@ def process_weights_after_loading_fused_gate_up_linear(
torch.ops.trtllm.block_scale_interleave(ws_swapped),
requires_grad=False)

def post_load_weights(self, module: Linear):
def transform_weights(self, module: Linear) -> None:
"""Pad weight and weight_scale tensors to meet torch trtllm NVFP4 GEMM alignment requirements."""
super().post_load_weights(module)
super().transform_weights(module)
row_alignment, col_alignment = 32, 16
row_pad_size = (row_alignment - module.weight.size(0)) % row_alignment
col_pad_size = (col_alignment - module.weight.size(1)) % col_alignment
Expand Down Expand Up @@ -1873,10 +1876,10 @@ class W4A16NVFP4LinearMethod(NVFP4LinearMethod):
its fused path is SM>=100-gated upstream.
"""

def post_load_weights(self, module: Linear):
def transform_weights(self, module: Linear) -> None:
# Skip parent's 32x16 weight padding (apply() accepts [N, K/2] as-is)
# and un-swizzle per-block scale once at load.
LinearMethodBase.post_load_weights(self, module)
LinearMethodBase.transform_weights(self, module)
pad_rows = fp4_utils.pad_up(module.out_features, 128)
pad_cols = fp4_utils.pad_up(
module.in_features // module.scaling_vector_size, 4)
Expand Down Expand Up @@ -2914,6 +2917,7 @@ def __init__(
dtype=self.dtype) if reduce_output else None

self._weights_created = False
self._weights_transformed = False
self.reduce_output = reduce_output
self.use_custom_cublas_mm = use_custom_cublas_mm
self.use_cute_dsl_bf16_gemm = use_cute_dsl_bf16_gemm
Expand Down Expand Up @@ -2966,6 +2970,7 @@ def create_weights(self):
self.dtype)

self._weights_created = True
self._weights_transformed = False

@property
def has_any_quant(self):
Expand Down Expand Up @@ -3127,6 +3132,7 @@ def load_weights(self,
assert allow_partial_loading is False, (
f"{type(self.quant_method).__name__} does not support "
"allow_partial_loading")
self._weights_transformed = False
self.quant_method.load_weights(
self,
weights,
Expand All @@ -3136,8 +3142,17 @@ def load_weights(self,
def process_weights_after_loading(self):
    """Delegate post-loading weight processing to the quant method.

    The quant method receives this module itself as its only argument.
    """
    quant_method = self.quant_method
    quant_method.process_weights_after_loading(self)

def post_load_weights(self):
self.quant_method.post_load_weights(self)
def transform_weights(self) -> None:
    """Apply the quant method's weight transform unless already applied.

    ``_weights_transformed`` acts as the guard: while it is set this call
    does nothing. It is raised only after the quant method returns, so a
    transform that raises leaves the flag cleared.
    """
    if not self._weights_transformed:
        self.quant_method.transform_weights(self)
        self._weights_transformed = True
Comment thread
chienchunhung marked this conversation as resolved.

def cache_derived_state(self) -> None:
    """Mark the weights as transformed without running any transform.

    Only the ``_weights_transformed`` flag is touched; the quant method is
    not called. A subsequent ``transform_weights()`` therefore returns early.
    """
    self._weights_transformed = True

def post_load_weights(self) -> None:
    """Post-load entry point for this module.

    Simply forwards to ``transform_weights()``, which decides for itself
    whether any work is still needed.
    """
    self.transform_weights()

def pre_reload_weights(self):
assert hasattr(
Expand Down
2 changes: 2 additions & 0 deletions tests/unittest/_torch/pyexecutor/test_model_loader_gms.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,8 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
# ``model=model`` is passed for symmetry with the LoadFormat.AUTO
# path (see model_loader.py); HF ignores it, MX uses it for direct
# P2P writes when MX+GMS composition eventually lands.
# ``source_identity`` is included so format-specific loaders can
# publish the same compatibility fingerprint the RO path validates.
checkpoint_loader.load_weights.assert_called_once_with(
"/ckpt",
mapping=loader.mapping,
Expand Down
64 changes: 64 additions & 0 deletions tests/unittest/_torch/pyexecutor/test_model_loader_mx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import torch
from torch import nn

from tensorrt_llm._torch.modules import attention as attention_mod
from tensorrt_llm._torch.modules.attention import MLA
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._torch.pyexecutor import model_loader as model_loader_mod
from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader
from tensorrt_llm.llmapi.llm_args import LoadFormat
Expand Down Expand Up @@ -282,3 +285,64 @@ def test_reset_weights_transformed_only_resets_existing_flags():
assert model.child._weights_transformed is False
assert model.transformed_child._weights_transformed is False
assert not hasattr(model.removed_child, "_weights_transformed")


def test_linear_transform_weights_is_idempotent():
    """``Linear.transform_weights()`` runs the quant method once per reset.

    Also checks that ``cache_derived_state()`` marks the weights as
    transformed without calling the quant method, and that the mark
    suppresses the next transform request.
    """
    linear = Linear(
        1,
        1,
        bias=False,
        reduce_output=False,
        skip_create_weights_in_init=True,
    )
    linear.quant_method = MagicMock()

    # The first call does the work; the post-load wrapper is then a no-op.
    linear.transform_weights()
    linear.post_load_weights()

    linear.quant_method.transform_weights.assert_called_once_with(linear)
    assert linear._weights_transformed is True

    # Clearing the flag re-arms the transform.
    linear._weights_transformed = False
    linear.post_load_weights()
    assert linear.quant_method.transform_weights.call_count == 2

    # cache_derived_state() must set the flag without touching the quant
    # method.
    linear._weights_transformed = False
    linear.cache_derived_state()
    assert linear._weights_transformed is True
    assert linear.quant_method.transform_weights.call_count == 2

    # With the flag set this way, a later post-load call must not transform.
    linear.post_load_weights()
    assert linear.quant_method.transform_weights.call_count == 2


def test_mla_transform_weights_is_idempotent(monkeypatch):
    """``MLA.transform_weights()`` resmooths each projection only once.

    Uses a bare ``MLA`` instance (``__init__`` bypassed) with string
    stand-ins for the tensors and a recording replacement for
    ``resmooth_parameters``. Also checks that ``cache_derived_state()``
    suppresses a later transform.
    """
    monkeypatch.setattr(attention_mod, "get_sm_version", lambda: 120)
    quant_mode = SimpleNamespace(has_fp8_block_scales=lambda: True)
    mla = MLA.__new__(MLA)
    mla._weights_transformed = False
    mla.kv_b_proj = SimpleNamespace(
        quant_config=SimpleNamespace(quant_mode=quant_mode))
    mla.k_b_proj_trans = "k_weight"
    mla.k_b_proj_trans_scale = "k_scale"
    mla.v_b_proj = "v_weight"
    mla.v_b_proj_scale = "v_scale"
    calls = []

    def fake_resmooth(weight, scale, recipe):
        # Record the request and return recognisable transformed values.
        calls.append((weight, scale, recipe))
        return f"{weight}_transformed", f"{scale}_transformed"

    mla.resmooth_parameters = fake_resmooth

    # The second entry point must be a no-op once the first has run.
    MLA.transform_weights(mla)
    MLA.post_load_weights(mla)

    assert calls == [
        ("k_weight", "k_scale", (1, 128, 128)),
        ("v_weight", "v_scale", (1, 128, 128)),
    ]
    assert mla.k_b_proj_trans == "k_weight_transformed"
    assert mla.k_b_proj_trans_scale == "k_scale_transformed"
    assert mla.v_b_proj == "v_weight_transformed"
    assert mla.v_b_proj_scale == "v_scale_transformed"
    assert mla._weights_transformed is True

    # cache_derived_state() marks the weights transformed without any
    # resmoothing.
    mla._weights_transformed = False
    MLA.cache_derived_state(mla)
    assert mla._weights_transformed is True
    assert len(calls) == 2

    # A transform requested after that mark must not resmooth again.
    MLA.transform_weights(mla)
    assert len(calls) == 2
    assert mla.k_b_proj_trans == "k_weight_transformed"
    assert mla.v_b_proj == "v_weight_transformed"
Loading