Skip to content

Commit 3e253b4

Browse files
authored
Merge branch 'main' into fix-bucket-batch-sampler-cache-alignment
2 parents 94e2c3d + 8ee10d8 commit 3e253b4

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

tests/models/testing_utils/quantization.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -805,6 +805,10 @@ class TorchAoConfigMixin:
805805
@staticmethod
806806
def _get_quant_config(config_name):
807807
config_cls = getattr(_torchao_quantization, config_name)
808+
# TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
809+
if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
810+
return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))
811+
808812
return TorchAoConfig(config_cls())
809813

810814
def _create_quantized_model(self, config_name, **extra_kwargs):
@@ -819,8 +823,10 @@ def _verify_if_layer_quantized(self, name, module, config_kwargs):
819823
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"
820824

821825

822-
# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
823-
_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA")
826+
# int4wo requires CUDA or XPU ops (_convert_weight_to_int4pack)
827+
_int4wo_skip = pytest.mark.skipif(
828+
torch_device not in ["cuda", "xpu"], reason="int4wo quantization requires CUDA or XPU"
829+
)
824830

825831

826832
@is_torchao

0 commit comments

Comments (0)