Skip to content

Commit 71027a6

Browse files
committed
[TRTLLM-13247][fix] Address CodeRabbit review comments
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent 896e764 commit 71027a6

3 files changed

Lines changed: 15 additions & 1 deletion

File tree

tensorrt_llm/_torch/modules/attention.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3046,5 +3046,8 @@ def transform_weights(self) -> None:
30463046
self.v_b_proj, self.v_b_proj_scale, recipe=(1, 128, 128))
30473047
self._weights_transformed = True
30483048

3049+
def cache_derived_state(self) -> None:
3050+
self._weights_transformed = True
3051+
30493052
def post_load_weights(self) -> None:
30503053
self.transform_weights()

tensorrt_llm/_torch/modules/linear.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ def load_weights(self,
381381
self.process_weights_after_loading(module)
382382

383383
def transform_weights(self, module: Linear) -> None:
384-
pass
384+
...
385385

386386
def post_load_weights(self, module: Linear) -> None:
387387
self.transform_weights(module)
@@ -3148,6 +3148,9 @@ def transform_weights(self) -> None:
31483148
self.quant_method.transform_weights(self)
31493149
self._weights_transformed = True
31503150

3151+
def cache_derived_state(self) -> None:
3152+
self._weights_transformed = True
3153+
31513154
def post_load_weights(self) -> None:
31523155
self.transform_weights()
31533156

tests/unittest/_torch/pyexecutor/test_model_loader_mx.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,10 @@ def test_linear_transform_weights_is_idempotent():
307307
linear.post_load_weights()
308308
assert linear.quant_method.transform_weights.call_count == 2
309309

310+
linear._weights_transformed = False
311+
linear.cache_derived_state()
312+
assert linear._weights_transformed is True
313+
310314

311315
def test_mla_transform_weights_is_idempotent(monkeypatch):
312316
monkeypatch.setattr(attention_mod, "get_sm_version", lambda: 120)
@@ -338,3 +342,7 @@ def fake_resmooth(weight, scale, recipe):
338342
assert mla.v_b_proj == "v_weight_transformed"
339343
assert mla.v_b_proj_scale == "v_scale_transformed"
340344
assert mla._weights_transformed is True
345+
346+
mla._weights_transformed = False
347+
MLA.cache_derived_state(mla)
348+
assert mla._weights_transformed is True

0 commit comments

Comments
 (0)