File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed
tests/_test_utils/torch/quantization Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -96,6 +96,15 @@ def save_restore_test(
9696 model_ref .load_state_dict (model_quant .state_dict ())
9797 assert torch .allclose (model_quant (calib_data [0 ]), model_ref (calib_data [0 ]))
9898
99+ # Verify that TensorQuantizer subclass types are preserved after restore
100+ for name_q , mod_q in model_quant .named_modules ():
101+ if name_q .endswith ("quantizer" ):
102+ mod_r = dict (model_ref .named_modules ())[name_q ]
103+ assert type (mod_q ) is type (mod_r ), (
104+ f"Quantizer class mismatch for '{ name_q } ': "
105+ f"expected { type (mod_q ).__name__ } , got { type (mod_r ).__name__ } "
106+ )
107+
99108 if version is not None and Version (version ) < Version ("0.29" ):
100109 # Rest of the tests are not needed for version < 0.29
101110 return
You can’t perform that action at this time.
0 commit comments