Implement `_dequantize` for TorchAO quantizer #13538
Conversation
Hi @sayakpaul. Would you please review this PR? Thanks!
```python
assert isinstance(module.weight, TorchAOBaseTensor), (
    f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
)
```
Can we also enable dequantization tests for the TorchAO tester mixin?
- Add `_dequantize()` method in `TorchAoHfQuantizer` that dequantizes `TorchAOBaseTensor` weights back to standard `nn.Parameter`
- Fix `_verify_if_layer_quantized` to check `isinstance(weight, TorchAOBaseTensor)` so dequantized layers are correctly detected as non-quantized
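For a sense of the workflow this enables, here is a hedged usage sketch. The model id and quant type are illustrative assumptions, not part of this PR:

```python
import torch

from diffusers import FluxTransformer2DModel, TorchAoConfig

# Illustrative checkpoint and quant type; any TorchAO-quantized model works.
quant_config = TorchAoConfig("int8wo")
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
model.dequantize()  # weights become standard nn.Parameter tensors again
```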
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Force-pushed from a819214 to 9102fb8
```python
from torchao.utils import TorchAOBaseTensor


for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
```
TorchAOBaseTensor does not expose dequantize as a public API; it is defined on child classes. I agree that it would make sense to do so in the future. If you want to be safe here, it might be better to check for individual tensor subclasses that do expose it.
Thanks for the review @vkuzo! You're right that dequantize() is defined on child classes rather than on TorchAOBaseTensor itself. I've added a hasattr guard so we safely skip any subclass that doesn't expose it. In practice all quantized tensor subclasses we encounter do implement dequantize(), but this makes it future-proof.
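A minimal sketch of the guarded loop described above; `model` is assumed to be the module being dequantized, and the PR's exact lines may differ:

```python
import torch.nn as nn

from torchao.utils import TorchAOBaseTensor

for name, module in model.named_modules():
    if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):
        # dequantize() is defined on the concrete tensor subclass, not on
        # TorchAOBaseTensor itself, so skip any subclass that lacks it.
        if not hasattr(module.weight, "dequantize"):
            continue
        module.weight = nn.Parameter(module.weight.dequantize())
```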
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Force-pushed from 95d0118 to df36f1a
```python
    ],
    ids=["int4wo", "int8wo", "int8dq"],
)
def test_torchao_dequantize(self, quant_type):
```
I ran the tests with the following command: ``
And there are test failures:
```
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int4wo] - NotImplementedError: Int4Tensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload...
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int8wo] - RuntimeError: mat1 and mat2 must have the same dtype, but got Float and BFloat16
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int8dq] - NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/funct...
```

With the following diff I managed to get it down to two:
```diff
diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py
index 840eaa338..e73c31561 100644
--- a/tests/models/transformers/test_models_transformer_flux.py
+++ b/tests/models/transformers/test_models_transformer_flux.py
@@ -367,6 +367,10 @@ class TestFluxTransformerQuanto(FluxTransformerTesterConfig, QuantoTesterMixin):
 class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
     """TorchAO quantization tests for Flux Transformer."""
 
+    @property
+    def torch_dtype(self):
+        return torch.bfloat16
+
 class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
     @property
```

```
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int4wo] - NotImplementedError: Int4Tensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload...
FAILED tests/models/transformers/test_models_transformer_flux.py::TestFluxTransformerTorchAo::test_torchao_dequantize[int8dq] - NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/funct...
```

I am on an H100.
Force-pushed from 2cbe719 to 83431bf
Force-pushed from 83431bf to 450d0e4
Hi @sayakpaul. I have fixed the dtype issue and skipped the cases that hit unimplemented TorchAO dispatch ops.
What does this PR do?
Implements the `_dequantize()` method for `TorchAoHfQuantizer`, enabling `model.dequantize()` to convert TorchAO-quantized models back to standard float weights.

Changes
- Add `_dequantize()` method: iterates all `nn.Linear` modules, calls `weight.dequantize()` on `TorchAOBaseTensor` weights, replaces them with standard `nn.Parameter`, and resets any overridden `extra_repr`.
- Fix `_verify_if_layer_quantized`: added an `isinstance(module.weight, TorchAOBaseTensor)` check so that dequantized layers (which are still `nn.Linear` but with plain tensor weights) are correctly detected as non-quantized (see the sketch below).
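A minimal sketch of the `_verify_if_layer_quantized` check described above; the method signature is an assumption based on the description, not the verbatim patch:

```python
import torch.nn as nn

from torchao.utils import TorchAOBaseTensor

def _verify_if_layer_quantized(self, module: nn.Module) -> bool:
    # A dequantized layer is still nn.Linear, but its weight is a plain
    # tensor again, so check the weight type rather than the module type.
    return isinstance(module, nn.Linear) and isinstance(
        module.weight, TorchAOBaseTensor
    )
```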