Skip to content

Commit f64bddf

Browse files
committed
test: add xpu code device support
1 parent 474c013 commit f64bddf

1 file changed

Lines changed: 17 additions & 1 deletion

File tree

test/utils/test_device.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ def test_device_creation():
2222
assert Device.cpu().type == DeviceType.CPU
2323
assert Device.gpu().type == DeviceType.GPU
2424
assert Device.mps().type == DeviceType.MPS
25+
assert Device.xpu().type == DeviceType.XPU
2526
assert Device.disk().type == DeviceType.DISK
2627

2728
assert Device.from_str("cpu") == Device.cpu()
2829
assert Device.from_str("cuda:1") == Device.gpu(1)
2930
assert Device.from_str("disk") == Device.disk()
3031
assert Device.from_str("mps:0") == Device(DeviceType.MPS, 0)
32+
assert Device.from_str("xpu:0") == Device(DeviceType.XPU, 0)
3133

3234
with pytest.raises(ValueError, match="Device id must be >= 0"):
3335
Device.gpu(-1)
@@ -115,23 +117,37 @@ def test_component_device_multiple():
115117
assert multiple.first_device == ComponentDevice.from_single(Device.cpu())
116118

117119

120+
@patch("torch.xpu.is_available")
118121
@patch("torch.backends.mps.is_available")
119122
@patch("torch.cuda.is_available")
120-
def test_component_device_resolution(torch_cuda_is_available, torch_backends_mps_is_available):
123+
def test_component_device_resolution(torch_cuda_is_available, torch_backends_mps_is_available, torch_xpu_is_available):
121124
assert ComponentDevice.resolve_device(ComponentDevice.from_single(Device.cpu()))._single_device == Device.cpu()
122125

123126
torch_cuda_is_available.return_value = True
124127
assert ComponentDevice.resolve_device(None)._single_device == Device.gpu(0)
125128

126129
torch_cuda_is_available.return_value = False
130+
torch_xpu_is_available = True
131+
torch_backends_mps_is_available.return_value = False
132+
assert ComponentDevice.resolve_device(None)._single_device == Device.xpu()
133+
134+
torch_cuda_is_available.return_value = False
135+
torch_xpu_is_available = False
127136
torch_backends_mps_is_available.return_value = True
128137
assert ComponentDevice.resolve_device(None)._single_device == Device.mps()
129138

130139
torch_cuda_is_available.return_value = False
140+
torch_xpu_is_available = False
131141
torch_backends_mps_is_available.return_value = False
132142
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()
133143

134144
torch_cuda_is_available.return_value = False
135145
torch_backends_mps_is_available.return_value = True
136146
os.environ["HAYSTACK_MPS_ENABLED"] = "false"
137147
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()
148+
149+
torch_cuda_is_available.return_value = False
150+
torch_xpu_is_available = True
151+
torch_backends_mps_is_available.return_value = False
152+
os.environ["HAYSTACK_XPU_ENABLED"] = "false"
153+
assert ComponentDevice.resolve_device(None)._single_device == Device.cpu()

0 commit comments

Comments
 (0)