|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import pickle |
| 17 | + |
16 | 18 | import pytest |
17 | 19 | import torch |
18 | 20 | import torch.nn as nn |
19 | 21 |
|
20 | 22 | import modelopt.torch.quantization as mtq |
21 | | -from modelopt.torch.quantization.nn import QuantLinearConvBase |
| 23 | +from modelopt.torch.quantization.nn import QuantLinearConvBase, TensorQuantizer |
22 | 24 |
|
23 | 25 | try: |
24 | 26 | from accelerate.hooks import ModelHook, add_hook_to_module |
@@ -51,3 +53,30 @@ def test_linear_with_accelerate_monkey_patched_forward(): |
51 | 53 |
|
52 | 54 | assert module_test.input_quantizer.amax is not None |
53 | 55 | assert module_test.weight_quantizer.amax is not None |
| 56 | + |
| 57 | + |
| 58 | +def test_tensor_quantizer_modelopt_state_with_accelerate_hook(): |
| 59 | + """Verify accelerate hook attributes are excluded from modelopt state. |
| 60 | +
|
| 61 | + When accelerate's add_hook_to_module patches a TensorQuantizer, it adds |
| 62 | + _hf_hook, _old_forward, and an instance-level forward (a functools.partial |
| 63 | + wrapping a local function). These must be excluded from the modelopt state |
| 64 | + dict, otherwise torch.save / pickle will fail with: |
| 65 | + AttributeError: Can't get local object 'add_hook_to_module.<locals>.new_forward' |
| 66 | + """ |
| 67 | + tq = TensorQuantizer() |
| 68 | + add_hook_to_module(tq, ModelHook()) |
| 69 | + |
| 70 | + # The hook should have injected these instance attributes |
| 71 | + assert hasattr(tq, "_hf_hook") |
| 72 | + assert hasattr(tq, "_old_forward") |
| 73 | + assert "forward" in tq.__dict__ |
| 74 | + |
| 75 | + # None of the accelerate attributes should appear in the modelopt state |
| 76 | + state = tq.get_modelopt_state() |
| 77 | + accelerate_attrs = {"_hf_hook", "_old_forward", "forward"} |
| 78 | + leaked = accelerate_attrs & state.keys() |
| 79 | + assert not leaked, f"Accelerate attributes leaked into modelopt state: {leaked}" |
| 80 | + |
| 81 | + # The state dict must be picklable (torch.save uses pickle internally) |
| 82 | + pickle.dumps(state) |
0 commit comments