Skip to content

Commit 27a0fb6

Browse files
committed
[Bug fix] Fake quantized model save after HF accelerate hooks are added
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent 9e38041 commit 27a0fb6

2 files changed

Lines changed: 34 additions & 1 deletion

File tree

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ class TensorQuantizer(nn.Module):
156156
"_padding",
157157
# Extra flags added by huggingface
158158
"_is_hf_initialized",
159+
# Extra flags added by accelerate
160+
"_hf_hook",
161+
"_old_forward",
162+
"forward",
159163
# Extra flags added by deepspeed
160164
"ds_external_parameters",
161165
"all_parameters",

tests/unit/torch/quantization/plugins/test_accelerate.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import pickle
17+
1618
import pytest
1719
import torch
1820
import torch.nn as nn
1921

2022
import modelopt.torch.quantization as mtq
21-
from modelopt.torch.quantization.nn import QuantLinearConvBase
23+
from modelopt.torch.quantization.nn import QuantLinearConvBase, TensorQuantizer
2224

2325
try:
2426
from accelerate.hooks import ModelHook, add_hook_to_module
@@ -51,3 +53,30 @@ def test_linear_with_accelerate_monkey_patched_forward():
5153

5254
assert module_test.input_quantizer.amax is not None
5355
assert module_test.weight_quantizer.amax is not None
56+
57+
58+
def test_tensor_quantizer_modelopt_state_with_accelerate_hook():
    """Check that accelerate's hook attributes stay out of the modelopt state.

    Calling accelerate's add_hook_to_module on a TensorQuantizer attaches
    _hf_hook and _old_forward, plus an instance-level forward (a
    functools.partial wrapping a local function). If any of these ended up in
    the modelopt state dict, torch.save / pickle would fail with:
        AttributeError: Can't get local object 'add_hook_to_module.<locals>.new_forward'
    """
    quantizer = TensorQuantizer()
    add_hook_to_module(quantizer, ModelHook())

    # Sanity check: the hook really did inject its instance attributes.
    for injected in ("_hf_hook", "_old_forward"):
        assert hasattr(quantizer, injected)
    # ``forward`` must be looked up in the instance dict: the class always
    # defines ``forward``, so hasattr() could not detect the per-instance patch.
    assert "forward" in quantizer.__dict__

    # The modelopt state must not carry any of the accelerate attributes.
    modelopt_state = quantizer.get_modelopt_state()
    hook_attrs = {"_hf_hook", "_old_forward", "forward"}
    leaked = hook_attrs.intersection(modelopt_state.keys())
    assert not leaked, f"Accelerate attributes leaked into modelopt state: {leaked}"

    # torch.save pickles under the hood, so the state has to survive pickling.
    pickle.dumps(modelopt_state)

0 commit comments

Comments
 (0)