Skip to content
Merged
1 change: 1 addition & 0 deletions .github/workflows/slow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ env:
PYTHON_VERSION: "3.9"
HATCH_VERSION: "1.14.1"
HAYSTACK_MPS_ENABLED: false
HAYSTACK_XPU_ENABLED: false

on:
workflow_dispatch: # Activate this workflow manually
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,9 @@ jobs:
needs: unit-tests
runs-on: windows-latest
timeout-minutes: 30
env:
HAYSTACK_XPU_ENABLED: false

steps:
- uses: actions/checkout@v4

Expand Down
22 changes: 21 additions & 1 deletion haystack/utils/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class DeviceType(Enum):
GPU = "cuda"
DISK = "disk"
MPS = "mps"
XPU = "xpu"

def __str__(self):
return self.value
Expand Down Expand Up @@ -126,6 +127,16 @@ def mps() -> "Device":
"""
return Device(DeviceType.MPS)

@staticmethod
def xpu() -> "Device":
    """
    Create a generic Intel XPU device (Intel GPU via PyTorch's XPU backend).

    :returns:
        The XPU device.
    """
    device_type = DeviceType.XPU
    return Device(device_type)

@staticmethod
def from_str(string: str) -> "Device":
"""
Expand Down Expand Up @@ -482,7 +493,7 @@ def _get_default_device() -> Device:
Return the default device for Haystack.

Precedence:
GPU > MPS > CPU. If PyTorch is not installed, only CPU is available.
GPU > XPU > MPS > CPU. If PyTorch is not installed, only CPU is available.

:returns:
The default device.
Expand All @@ -496,12 +507,21 @@ def _get_default_device() -> Device:
and os.getenv("HAYSTACK_MPS_ENABLED", "true") != "false"
)
has_cuda = torch.cuda.is_available()
has_xpu = (
hasattr(torch, "xpu")
and hasattr(torch.xpu, "is_available")
and torch.xpu.is_available()
and os.getenv("HAYSTACK_XPU_ENABLED", "true") != "false"
)
except ImportError:
has_mps = False
has_cuda = False
has_xpu = False

if has_cuda:
return Device.gpu()
elif has_xpu:
return Device.xpu()
elif has_mps:
return Device.mps()
else:
Expand Down
19 changes: 18 additions & 1 deletion test/utils/test_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@ def test_device_creation():
assert Device.cpu().type == DeviceType.CPU
assert Device.gpu().type == DeviceType.GPU
assert Device.mps().type == DeviceType.MPS
assert Device.xpu().type == DeviceType.XPU
assert Device.disk().type == DeviceType.DISK

assert Device.from_str("cpu") == Device.cpu()
assert Device.from_str("cuda:1") == Device.gpu(1)
assert Device.from_str("disk") == Device.disk()
assert Device.from_str("mps:0") == Device(DeviceType.MPS, 0)
assert Device.from_str("xpu:0") == Device(DeviceType.XPU, 0)

with pytest.raises(ValueError, match="Device id must be >= 0"):
Device.gpu(-1)
Expand Down Expand Up @@ -115,23 +117,38 @@ def test_component_device_multiple():
assert multiple.first_device == ComponentDevice.from_single(Device.cpu())


@patch("torch.xpu.is_available")
@patch("torch.backends.mps.is_available")
@patch("torch.cuda.is_available")
def test_component_device_resolution(torch_cuda_is_available, torch_backends_mps_is_available):
def test_component_device_resolution(torch_cuda_is_available, torch_backends_mps_is_available, torch_xpu_is_available):
assert ComponentDevice.resolve_device(ComponentDevice.from_single(Device.cpu()))._single_device == Device.cpu()

torch_cuda_is_available.return_value = True
assert ComponentDevice.resolve_device(None)._single_device == Device.gpu(0)

torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = True
torch_backends_mps_is_available.return_value = False
assert ComponentDevice.resolve_device(None)._single_device == Device.xpu()

torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = False
torch_backends_mps_is_available.return_value = True
assert ComponentDevice.resolve_device(None)._single_device == Device.mps()

torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = False
torch_backends_mps_is_available.return_value = False
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()

torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = False
torch_backends_mps_is_available.return_value = True
os.environ["HAYSTACK_MPS_ENABLED"] = "false"
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()

torch_cuda_is_available.return_value = False
torch_xpu_is_available.return_value = True
os.environ["HAYSTACK_XPU_ENABLED"] = "false"
torch_backends_mps_is_available.return_value = False
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()
Loading