Skip to content

Commit f2be8bd

Browse files
change minimum version guard for torchao to 0.15.0 (#13355)
1 parent 7da22b9 commit f2be8bd

File tree

3 files changed

+17
-20
lines changed

3 files changed

+17
-20
lines changed

src/diffusers/quantizers/quantization_config.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -470,8 +470,8 @@ def __init__(
470470
self.post_init()
471471

472472
def post_init(self):
473-
if is_torchao_version("<=", "0.9.0"):
474-
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
473+
if is_torchao_version("<", "0.15.0"):
474+
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
475475

476476
from torchao.quantization.quant_api import AOBaseConfig
477477

@@ -495,8 +495,8 @@ def to_dict(self):
495495
@classmethod
496496
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
497497
"""Create configuration from a dictionary."""
498-
if not is_torchao_version(">", "0.9.0"):
499-
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
498+
if not is_torchao_version(">=", "0.15.0"):
499+
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
500500
config_dict = config_dict.copy()
501501
quant_type = config_dict.pop("quant_type")
502502

src/diffusers/quantizers/torchao/torchao_quantizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def _update_torch_safe_globals():
113113
is_torch_available()
114114
and is_torch_version(">=", "2.6.0")
115115
and is_torchao_available()
116-
and is_torchao_version(">=", "0.7.0")
116+
and is_torchao_version(">=", "0.15.0")
117117
):
118118
_update_torch_safe_globals()
119119

@@ -168,10 +168,10 @@ def validate_environment(self, *args, **kwargs):
168168
raise ImportError(
169169
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
170170
)
171-
torchao_version = version.parse(importlib.metadata.version("torch"))
172-
if torchao_version < version.parse("0.7.0"):
171+
torchao_version = version.parse(importlib.metadata.version("torchao"))
172+
if torchao_version < version.parse("0.15.0"):
173173
raise RuntimeError(
174-
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
174+
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
175175
)
176176

177177
self.offload = False

tests/quantization/torchao/test_torchao.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@
1414
# limitations under the License.
1515

1616
import gc
17-
import importlib.metadata
1817
import tempfile
1918
import unittest
2019
from typing import List
2120

2221
import numpy as np
23-
from packaging import version
2422
from parameterized import parameterized
2523
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
2624

@@ -82,18 +80,17 @@ def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
8280
Float8WeightOnlyConfig,
8381
Int4WeightOnlyConfig,
8482
Int8DynamicActivationInt8WeightConfig,
83+
Int8DynamicActivationIntxWeightConfig,
8584
Int8WeightOnlyConfig,
85+
IntxWeightOnlyConfig,
8686
)
8787
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
8888
from torchao.utils import get_model_size_in_bytes
8989

90-
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
91-
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
92-
9390

9491
@require_torch
9592
@require_torch_accelerator
96-
@require_torchao_version_greater_or_equal("0.14.0")
93+
@require_torchao_version_greater_or_equal("0.15.0")
9794
class TorchAoConfigTest(unittest.TestCase):
9895
def test_to_dict(self):
9996
"""
@@ -128,7 +125,7 @@ def test_repr(self):
128125
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
129126
@require_torch
130127
@require_torch_accelerator
131-
@require_torchao_version_greater_or_equal("0.14.0")
128+
@require_torchao_version_greater_or_equal("0.15.0")
132129
class TorchAoTest(unittest.TestCase):
133130
def tearDown(self):
134131
gc.collect()
@@ -527,7 +524,7 @@ def test_sequential_cpu_offload(self):
527524
inputs = self.get_dummy_inputs(torch_device)
528525
_ = pipe(**inputs)
529526

530-
@require_torchao_version_greater_or_equal("0.9.0")
527+
@require_torchao_version_greater_or_equal("0.15.0")
531528
def test_aobase_config(self):
532529
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
533530
components = self.get_dummy_components(quantization_config)
@@ -540,7 +537,7 @@ def test_aobase_config(self):
540537
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
541538
@require_torch
542539
@require_torch_accelerator
543-
@require_torchao_version_greater_or_equal("0.14.0")
540+
@require_torchao_version_greater_or_equal("0.15.0")
544541
class TorchAoSerializationTest(unittest.TestCase):
545542
model_name = "hf-internal-testing/tiny-flux-pipe"
546543

@@ -650,7 +647,7 @@ def test_aobase_config(self):
650647
self._check_serialization_expected_slice(quant_type, expected_slice, device)
651648

652649

653-
@require_torchao_version_greater_or_equal("0.14.0")
650+
@require_torchao_version_greater_or_equal("0.15.0")
654651
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
655652
@property
656653
def quantization_config(self):
@@ -696,7 +693,7 @@ def test_torch_compile_with_group_offload_leaf(self, use_stream):
696693
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
697694
@require_torch
698695
@require_torch_accelerator
699-
@require_torchao_version_greater_or_equal("0.14.0")
696+
@require_torchao_version_greater_or_equal("0.15.0")
700697
@slow
701698
@nightly
702699
class SlowTorchAoTests(unittest.TestCase):
@@ -854,7 +851,7 @@ def test_memory_footprint_int8wo(self):
854851

855852
@require_torch
856853
@require_torch_accelerator
857-
@require_torchao_version_greater_or_equal("0.14.0")
854+
@require_torchao_version_greater_or_equal("0.15.0")
858855
@slow
859856
@nightly
860857
class SlowTorchAoPreserializedModelTests(unittest.TestCase):

0 commit comments

Comments
 (0)