Skip to content

Commit cf883fc

Browse files
committed
[TRTLLM-13246][feat] Wave 2: stage Linear and Attention transforms
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent: d07a5cd; commit: cf883fc

5 files changed

Lines changed: 96 additions & 13 deletions

File tree

tensorrt_llm/_torch/memory/gpu_memory_backend.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -42,8 +42,10 @@
4242
whose `post_load_weights()` is pure alias wiring; models that additionally
4343
rely on plain Python attributes set inside `post_load_weights()` (rather
4444
than registered `nn.Buffer` / `nn.Parameter` assignments) need to migrate
45-
those side effects to `transform_weights()` or `cache_derived_state()`
46-
before they are safe on the RO path.
45+
those side effects to `cache_derived_state()` or another hook that runs on
46+
the RO reader. One-shot tensor layout changes belong in `transform_weights()`
47+
on the writer; RO runs `setup_aliases()`, `materialize_module()`, then
48+
`cache_derived_state()`.
4749
"""
4850

4951
from contextlib import contextmanager

tensorrt_llm/_torch/modules/attention.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1245,6 +1245,7 @@ def __init__(
12451245
self.layer_idx = layer_idx
12461246
self.layer_idx_str = str(layer_idx)
12471247
self.dtype = dtype
1248+
self._weights_transformed = False
12481249

12491250
self.hidden_size = hidden_size
12501251
self.num_heads = num_attention_heads
@@ -1619,6 +1620,7 @@ def create_weights(self):
16191620
else:
16201621
self.k_b_proj_trans_scale = None
16211622
self.v_b_proj_scale = None
1623+
self._weights_transformed = False
16221624

16231625
def apply_rope(
16241626
self,
@@ -2998,7 +3000,9 @@ def resmooth_parameters(self,
29983000

29993001
return weight_param, scale_param
30003002

3001-
def post_load_weights(self):
3003+
def transform_weights(self) -> None:
3004+
if self._weights_transformed:
3005+
return
30023006
has_fp8_block_scales = (
30033007
self.kv_b_proj.quant_config
30043008
and self.kv_b_proj.quant_config.quant_mode.has_fp8_block_scales())
@@ -3011,3 +3015,7 @@ def post_load_weights(self):
30113015

30123016
self.v_b_proj, self.v_b_proj_scale = self.resmooth_parameters(
30133017
self.v_b_proj, self.v_b_proj_scale, recipe=(1, 128, 128))
3018+
self._weights_transformed = True
3019+
3020+
def post_load_weights(self) -> None:
3021+
self.transform_weights()

tensorrt_llm/_torch/modules/linear.py

Lines changed: 21 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -380,9 +380,12 @@ def load_weights(self,
380380
if not allow_partial_loading:
381381
self.process_weights_after_loading(module)
382382

383-
def post_load_weights(self, module: Linear):
383+
def transform_weights(self, module: Linear) -> None:
384384
pass
385385

386+
def post_load_weights(self, module: Linear) -> None:
387+
self.transform_weights(module)
388+
386389
def load_weight_scales(self, weights: List[Dict], *args, **kwargs):
387390
"""
388391
Load quantized weight scales from the checkpoint.
@@ -1241,8 +1244,8 @@ def load_weights_fused_gate_up_linear(
12411244
copy_weight_shard(module.weight_scale, scale, shard_offset,
12421245
shard_size)
12431246

1244-
def post_load_weights(self, module: Linear):
1245-
super().post_load_weights(module)
1247+
def transform_weights(self, module: Linear) -> None:
1248+
super().transform_weights(module)
12461249
if (is_sm_100f() and not (module.use_cute_dsl_blockscaling_mm
12471250
or module.disable_deep_gemm)) or \
12481251
get_sm_version() == 120:
@@ -1821,9 +1824,9 @@ def process_weights_after_loading_fused_gate_up_linear(
18211824
torch.ops.trtllm.block_scale_interleave(ws_swapped),
18221825
requires_grad=False)
18231826

1824-
def post_load_weights(self, module: Linear):
1827+
def transform_weights(self, module: Linear) -> None:
18251828
"""Pad weight and weight_scale tensors to meet torch trtllm NVFP4 GEMM alignment requirements."""
1826-
super().post_load_weights(module)
1829+
super().transform_weights(module)
18271830
row_alignment, col_alignment = 32, 16
18281831
row_pad_size = (row_alignment - module.weight.size(0)) % row_alignment
18291832
col_pad_size = (col_alignment - module.weight.size(1)) % col_alignment
@@ -1873,10 +1876,10 @@ class W4A16NVFP4LinearMethod(NVFP4LinearMethod):
18731876
its fused path is SM>=100-gated upstream.
18741877
"""
18751878

1876-
def post_load_weights(self, module: Linear):
1879+
def transform_weights(self, module: Linear) -> None:
18771880
# Skip parent's 32x16 weight padding (apply() accepts [N, K/2] as-is)
18781881
# and un-swizzle per-block scale once at load.
1879-
LinearMethodBase.post_load_weights(self, module)
1882+
LinearMethodBase.transform_weights(self, module)
18801883
pad_rows = fp4_utils.pad_up(module.out_features, 128)
18811884
pad_cols = fp4_utils.pad_up(
18821885
module.in_features // module.scaling_vector_size, 4)
@@ -2914,6 +2917,7 @@ def __init__(
29142917
dtype=self.dtype) if reduce_output else None
29152918

29162919
self._weights_created = False
2920+
self._weights_transformed = False
29172921
self.reduce_output = reduce_output
29182922
self.use_custom_cublas_mm = use_custom_cublas_mm
29192923
self.use_cute_dsl_bf16_gemm = use_cute_dsl_bf16_gemm
@@ -2966,6 +2970,7 @@ def create_weights(self):
29662970
self.dtype)
29672971

29682972
self._weights_created = True
2973+
self._weights_transformed = False
29692974

29702975
@property
29712976
def has_any_quant(self):
@@ -3127,6 +3132,7 @@ def load_weights(self,
31273132
assert allow_partial_loading is False, (
31283133
f"{type(self.quant_method).__name__} does not support "
31293134
"allow_partial_loading")
3135+
self._weights_transformed = False
31303136
self.quant_method.load_weights(
31313137
self,
31323138
weights,
@@ -3136,8 +3142,14 @@ def load_weights(self,
31363142
def process_weights_after_loading(self):
31373143
self.quant_method.process_weights_after_loading(self)
31383144

3139-
def post_load_weights(self):
3140-
self.quant_method.post_load_weights(self)
3145+
def transform_weights(self) -> None:
3146+
if self._weights_transformed:
3147+
return
3148+
self.quant_method.transform_weights(self)
3149+
self._weights_transformed = True
3150+
3151+
def post_load_weights(self) -> None:
3152+
self.transform_weights()
31413153

31423154
def pre_reload_weights(self):
31433155
assert hasattr(

tests/unittest/_torch/pyexecutor/test_model_loader_gms.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,8 +186,13 @@ def test_gms_load_branch(monkeypatch, is_rw, expected_events):
186186
# ``model=model`` is passed for symmetry with the LoadFormat.AUTO
187187
# path (see model_loader.py); HF ignores it, MX uses it for direct
188188
# P2P writes when MX+GMS composition eventually lands.
189+
# ``source_identity`` is included so format-specific loaders can
190+
# publish the same compatibility fingerprint the RO path validates.
189191
checkpoint_loader.load_weights.assert_called_once_with(
190-
"/ckpt", mapping=loader.mapping, model=model
192+
"/ckpt",
193+
mapping=loader.mapping,
194+
model=model,
195+
source_identity=loader._source_identity,
191196
)
192197
loader._call_load_weights.assert_called_once()
193198
backend.move_untracked_params.assert_called_once_with(model)

tests/unittest/_torch/pyexecutor/test_model_loader_mx.py

Lines changed: 56 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,9 @@
99
import torch
1010
from torch import nn
1111

12+
from tensorrt_llm._torch.modules import attention as attention_mod
13+
from tensorrt_llm._torch.modules.attention import MLA
14+
from tensorrt_llm._torch.modules.linear import Linear
1215
from tensorrt_llm._torch.pyexecutor import model_loader as model_loader_mod
1316
from tensorrt_llm._torch.pyexecutor.model_loader import ModelLoader
1417
from tensorrt_llm.llmapi.llm_args import LoadFormat
@@ -256,3 +259,56 @@ def test_reset_weights_transformed_only_resets_existing_flags():
256259
assert model.child._weights_transformed is False
257260
assert model.transformed_child._weights_transformed is False
258261
assert not hasattr(model.removed_child, "_weights_transformed")
262+
263+
264+
def test_linear_transform_weights_is_idempotent():
265+
linear = Linear(
266+
1,
267+
1,
268+
bias=False,
269+
reduce_output=False,
270+
skip_create_weights_in_init=True,
271+
)
272+
linear.quant_method = MagicMock()
273+
274+
linear.transform_weights()
275+
linear.post_load_weights()
276+
277+
linear.quant_method.transform_weights.assert_called_once_with(linear)
278+
assert linear._weights_transformed is True
279+
280+
linear._weights_transformed = False
281+
linear.post_load_weights()
282+
assert linear.quant_method.transform_weights.call_count == 2
283+
284+
285+
def test_mla_transform_weights_is_idempotent(monkeypatch):
286+
monkeypatch.setattr(attention_mod, "get_sm_version", lambda: 120)
287+
quant_mode = SimpleNamespace(has_fp8_block_scales=lambda: True)
288+
mla = MLA.__new__(MLA)
289+
mla._weights_transformed = False
290+
mla.kv_b_proj = SimpleNamespace(quant_config=SimpleNamespace(quant_mode=quant_mode))
291+
mla.k_b_proj_trans = "k_weight"
292+
mla.k_b_proj_trans_scale = "k_scale"
293+
mla.v_b_proj = "v_weight"
294+
mla.v_b_proj_scale = "v_scale"
295+
calls = []
296+
297+
def fake_resmooth(weight, scale, recipe):
298+
calls.append((weight, scale, recipe))
299+
return f"{weight}_transformed", f"{scale}_transformed"
300+
301+
mla.resmooth_parameters = fake_resmooth
302+
303+
MLA.transform_weights(mla)
304+
MLA.post_load_weights(mla)
305+
306+
assert calls == [
307+
("k_weight", "k_scale", (1, 128, 128)),
308+
("v_weight", "v_scale", (1, 128, 128)),
309+
]
310+
assert mla.k_b_proj_trans == "k_weight_transformed"
311+
assert mla.k_b_proj_trans_scale == "k_scale_transformed"
312+
assert mla.v_b_proj == "v_weight_transformed"
313+
assert mla.v_b_proj_scale == "v_scale_transformed"
314+
assert mla._weights_transformed is True

0 commit comments

Comments (0)