@@ -818,6 +818,10 @@ class TorchAoConfigMixin:
     @staticmethod
     def _get_quant_config(config_name):
         config_cls = getattr(_torchao_quantization, config_name)
+        # TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
+        if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
+            return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
+
         return TorchAoConfig(config_cls())
 
     def _create_quantized_model(self, config_name, **extra_kwargs):
@@ -832,10 +836,6 @@ def _verify_if_layer_quantized(self, name, module, config_kwargs):
         assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
 
 
-# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
-_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA")
-
-
 @is_torchao
 @require_accelerator
 @require_torchao_version_greater_or_equal("0.7.0")
@@ -861,7 +861,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],
@@ -873,7 +873,7 @@ def test_torchao_quantization_num_parameters(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],
@@ -888,7 +888,7 @@ def test_torchao_quantization_memory_footprint(self, quant_type):
     @pytest.mark.parametrize(
         "quant_type",
         [
-            pytest.param("int4wo", marks=_int4wo_skip),
+            "int4wo",
             "int8wo",
             "int8dq",
         ],