1 change: 1 addition & 0 deletions src/diffusers/quantizers/base.py
@@ -206,6 +206,7 @@ def dequantize(self, model):

         # Delete quantizer and quantization config
         del model.hf_quantizer
+        model.is_quantized = False
 
         return model

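For context, dequantization is driven from the model side; a minimal sketch of the flow this change affects, assuming ModelMixin.dequantize() delegates to hf_quantizer.dequantize(model) as shown above (checkpoint, quant type, and dtype here are illustrative):

import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

# Load a torchao-quantized transformer (illustrative repo and settings).
model = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=TorchAoConfig("int8wo"),
    torch_dtype=torch.bfloat16,
)

model.dequantize()  # removes hf_quantizer and, with this change, resets the flag
assert not model.is_quantized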
16 changes: 16 additions & 0 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -376,3 +376,19 @@ def is_trainable(self):
     @property
     def is_compileable(self) -> bool:
         return True
+
+    def _dequantize(self, model):
+        from torchao.utils import TorchAOBaseTensor
+
+        for name, module in model.named_modules():
+            if isinstance(module, nn.Linear) and isinstance(module.weight, TorchAOBaseTensor):

Contributor (@vkuzo), commenting on the isinstance check above:

TorchAOBaseTensor does not expose dequantize() as a public API; it is defined on child classes. I agree that it would make sense to do so in the future. If you want to be safe here, it might be better to check for the individual tensor subclasses that do expose it.
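For illustration, the stricter check suggested here might look like the sketch below; AffineQuantizedTensor is one torchao subclass that does expose dequantize(), but the exact set of subclasses to check is an assumption:

from torchao.dtypes import AffineQuantizedTensor

# Check a concrete subclass known to implement dequantize(), not the base class.
if isinstance(module.weight, AffineQuantizedTensor):
    dequantized_weight = module.weight.dequantize()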


Contributor Author:

Thanks for the review @vkuzo! You're right that dequantize() is defined on child classes rather than on TorchAOBaseTensor itself. I've added a hasattr guard so we safely skip any subclass that doesn't expose it. In practice all quantized tensor subclasses we encounter do implement dequantize(), but this makes it future-proof.

+                if not hasattr(module.weight, "dequantize"):
+                    continue
+                device = module.weight.device
+                dequantized_weight = module.weight.dequantize().to(device)
+                module.weight = nn.Parameter(dequantized_weight)
+                # Reset extra_repr if it was overridden
+                if hasattr(module.extra_repr, "__func__") and module.extra_repr.__func__ is not nn.Linear.extra_repr:
+                    module.extra_repr = types.MethodType(nn.Linear.extra_repr, module)
+
+        return model
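The extra_repr reset above relies on a standard Python pattern: binding the class method back onto the instance with types.MethodType replaces whatever per-instance extra_repr was installed during quantization. A standalone sketch of the pattern (not part of the PR):

import types
import torch.nn as nn

linear = nn.Linear(4, 4)

# Simulate a quantizer installing a per-instance extra_repr override.
linear.extra_repr = types.MethodType(lambda self: "in_features=4, quant=int8", linear)
print(linear)  # Linear(in_features=4, quant=int8)

# Rebind the stock nn.Linear.extra_repr to restore the default representation.
linear.extra_repr = types.MethodType(nn.Linear.extra_repr, linear)
print(linear)  # Linear(in_features=4, out_features=4, bias=True)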
5 changes: 5 additions & 0 deletions tests/models/testing_utils/quantization.py
@@ -820,7 +820,12 @@ def _create_quantized_model(self, config_name, **extra_kwargs):
         return self.model_class.from_pretrained(self.pretrained_model_name_or_path, **kwargs)
 
     def _verify_if_layer_quantized(self, name, module, config_kwargs):
+        from torchao.utils import TorchAOBaseTensor
+
         assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
+        assert isinstance(module.weight, TorchAOBaseTensor), (
+            f"Layer {name} weight is {type(module.weight)}, expected TorchAOBaseTensor"
+        )

Member, commenting on lines +826 to +828:

Can we also enable dequantization tests for TorchAO tester mixin?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.
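For reference, the enabled test might take roughly this shape; the method name, config name, and assertions below are assumptions for illustration, not the code actually added:

def test_dequantize(self):
    from torchao.utils import TorchAOBaseTensor

    model = self._create_quantized_model("int8wo")
    model.dequantize()
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # After dequantization, weights should be plain tensors again.
            assert not isinstance(module.weight, TorchAOBaseTensor), (
                f"Layer {name} still holds a quantized weight"
            )
    assert not model.is_quantized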



     # int4wo requires CUDA or XPU ops (_convert_weight_to_int4pack)
8 changes: 8 additions & 0 deletions tests/models/transformers/test_models_transformer_flux.py
@@ -367,6 +367,10 @@ def pretrained_model_kwargs(self):
 class TestFluxTransformerTorchAo(FluxTransformerTesterConfig, TorchAoTesterMixin):
     """TorchAO quantization tests for Flux Transformer."""
 
+    @property
+    def torch_dtype(self):
+        return torch.bfloat16
+

 class TestFluxTransformerGGUF(FluxTransformerTesterConfig, GGUFTesterMixin):
     @property
@@ -403,6 +407,10 @@ class TestFluxTransformerQuantoCompile(FluxTransformerTesterConfig, QuantoCompil
 class TestFluxTransformerTorchAoCompile(FluxTransformerTesterConfig, TorchAoCompileTesterMixin):
     """TorchAO + compile tests for Flux Transformer."""
 
+    @property
+    def torch_dtype(self):
+        return torch.bfloat16
+

 class TestFluxTransformerGGUFCompile(FluxTransformerTesterConfig, GGUFCompileTesterMixin):
     @property