diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py index 1aab0b240148..30d44a92c425 100644 --- a/tests/models/testing_utils/quantization.py +++ b/tests/models/testing_utils/quantization.py @@ -805,6 +805,10 @@ class TorchAoConfigMixin: @staticmethod def _get_quant_config(config_name): config_cls = getattr(_torchao_quantization, config_name) + # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU + if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu": + return TorchAoConfig(config_cls(int4_packing_format="plain_int32")) + return TorchAoConfig(config_cls()) def _create_quantized_model(self, config_name, **extra_kwargs): @@ -819,8 +823,10 @@ def _verify_if_layer_quantized(self, name, module, config_kwargs): assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}" -# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack) -_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA") +# int4wo requires CUDA or XPU ops (_convert_weight_to_int4pack) +_int4wo_skip = pytest.mark.skipif( + torch_device not in ["cuda", "xpu"], reason="int4wo quantization requires CUDA or XPU" +) @is_torchao