Skip to content
Merged
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions tests/models/testing_utils/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,10 @@ class TorchAoConfigMixin:
@staticmethod
def _get_quant_config(config_name):
config_cls = getattr(_torchao_quantization, config_name)
# TorchAO int4 quantization requires plain_int32 packing format on Intel XPU
if config_name == "Int4WeightOnlyConfig" and torch_device == "xpu":
return TorchAoConfig(config_cls(int4_packing_format="plain_int32"))

return TorchAoConfig(config_cls())

def _create_quantized_model(self, config_name, **extra_kwargs):
Expand All @@ -819,8 +823,8 @@ def _verify_if_layer_quantized(self, name, module, config_kwargs):
assert isinstance(module, torch.nn.Linear), f"Layer {name} is not Linear, got {type(module)}"


# int4wo requires CUDA-specific ops (_convert_weight_to_int4pack)
_int4wo_skip = pytest.mark.skipif(torch_device != "cuda", reason="int4wo quantization requires CUDA")
Comment thread
jiqing-feng marked this conversation as resolved.
# int4wo requires CUDA or XPU ops (_convert_weight_to_int4pack)
_int4wo_skip = pytest.mark.skipif(torch_device not in ["cuda", "xpu"], reason="int4wo quantization requires CUDA or XPU")


@is_torchao
Expand Down
Loading