Skip to content

Commit c974090

Browse files
committed
minor
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent d0dfae0 commit c974090

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

tests/_test_utils/torch/quantization/quantize_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ def save_restore_test(
9696
model_ref.load_state_dict(model_quant.state_dict())
9797
assert torch.allclose(model_quant(calib_data[0]), model_ref(calib_data[0]))
9898

99+
# Verify that TensorQuantizer subclass types are preserved after restore
100+
for name_q, mod_q in model_quant.named_modules():
101+
if name_q.endswith("quantizer"):
102+
mod_r = dict(model_ref.named_modules())[name_q]
103+
assert type(mod_q) is type(mod_r), (
104+
f"Quantizer class mismatch for '{name_q}': "
105+
f"expected {type(mod_q).__name__}, got {type(mod_r).__name__}"
106+
)
107+
99108
if version is not None and Version(version) < Version("0.29"):
100109
# Rest of the tests are not needed for version < 0.29
101110
return

0 commit comments

Comments
 (0)