Skip to content

Commit 1c4b424

Browse files
committed
Enable TorchAO int4wo quantization tests on XPU
- Remove the _int4wo_skip marker that restricted int4wo tests to CUDA only
- Add XPU-specific int4_packing_format='plain_int32' for Int4WeightOnlyConfig
1 parent c8c8401 commit 1c4b424

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

tests/models/testing_utils/quantization.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -818,6 +818,10 @@ class TorchAoConfigMixin:
818818
@staticmethod
819819
def _get_quant_config(config_name):
820820
config_cls = getattr(_torchao_quantization, config_name)
821+
# TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
822+
if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
823+
return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
824+
821825
return TorchAoConfig(config_cls())
822826

823827
def _create_quantized_model(self, config_name, **extra_kwargs):
@@ -832,10 +836,6 @@ def _verify_if_layer_quantized(self, name, module, config_kwargs):
832836
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
833837

834838

835-
# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
836-
_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA")
837-
838-
839839
@is_torchao
840840
@require_accelerator
841841
@require_torchao_version_greater_or_equal("0.7.0")
@@ -861,7 +861,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
861861
@pytest.mark.parametrize(
862862
"quant_type",
863863
[
864-
pytest.param("int4wo", marks=_int4wo_skip),
864+
"int4wo",
865865
"int8wo",
866866
"int8dq",
867867
],
@@ -873,7 +873,7 @@ def test_torchao_quantization_num_parameters(self, quant_type):
873873
@pytest.mark.parametrize(
874874
"quant_type",
875875
[
876-
pytest.param("int4wo", marks=_int4wo_skip),
876+
"int4wo",
877877
"int8wo",
878878
"int8dq",
879879
],
@@ -888,7 +888,7 @@ def test_torchao_quantization_memory_footprint(self, quant_type):
888888
@pytest.mark.parametrize(
889889
"quant_type",
890890
[
891-
pytest.param("int4wo", marks=_int4wo_skip),
891+
"int4wo",
892892
"int8wo",
893893
"int8dq",
894894
],

0 commit comments

Comments (0)