diff --git a/qubes/tests/devices_pci.py b/qubes/tests/devices_pci.py index 10704c4cd..18d783fa8 100644 --- a/qubes/tests/devices_pci.py +++ b/qubes/tests/devices_pci.py @@ -24,7 +24,7 @@ import qubes.tests import qubes.ext.pci from qubes.device_protocol import DeviceInterface -from qubes.utils import sbdf_to_path, path_to_sbdf +from qubes.utils import sbdf_to_path, path_to_sbdf, is_pci_path orig_open = open @@ -157,7 +157,17 @@ def test_011_path_to_sbdf2(self): path = path_to_sbdf("0000_00_18.4") self.assertEqual(path, "0000:00:18.4") + def test_020_is_pci_path(self): + self.assertTrue(is_pci_path("0000_00_18.4")) + def test_021_is_pci_path_false(self): + self.assertFalse(is_pci_path("0000_c6_00.0")) + + def test_022_is_pci_path_non_00_bus(self): + self.assertTrue(is_pci_path("0000_c0_00.0")) + + +@mock.patch("qubes.utils.SYSFS_BASE", tests_sysfs_path) class TC_10_PCI(qubes.tests.QubesTestCase): def setUp(self): super().setUp() diff --git a/qubes/utils.py b/qubes/utils.py index 1ddc53798..e88d10fef 100644 --- a/qubes/utils.py +++ b/qubes/utils.py @@ -373,6 +373,12 @@ def sbdf_to_path(device_id: str): ) sysfs_pci_devs_base = f"{SYSFS_BASE}/bus/pci/devices" + root_buses = [ + dev[3:] + for dev in os.listdir(f"{SYSFS_BASE}/devices") + if dev.startswith("pci") + ] + dev_match = regex.match(device_id) if not dev_match: raise ValueError("Invalid device identifier: {!r}".format(device_id)) @@ -380,7 +386,7 @@ def sbdf_to_path(device_id: str): segment = dev_match["segment"] else: segment = "0000" - if dev_match["bus"] == "00": + if f"{segment}:{dev_match['bus']}" in root_buses: return (f"{segment}_" if segment != "0000" else "") + ( f"{dev_match['bus']}_" f"{dev_match['device']}.{dev_match['function']}" @@ -488,8 +494,17 @@ def is_pci_path(device_id: str): :param device_id: device id to check :return: """ + + root_buses = [ + dev[3:].replace(":", "_") + for dev in os.listdir(f"{SYSFS_BASE}/devices") + if dev.startswith("pci") + ] + # add segment prefix for easier matching + if len(device_id) > 2 and device_id[2] == "_": + device_id = "0000_" + device_id path_re = re.compile( - r"\A([0-9a-f]{4}_)?00_[0-9a-f]{2}\.[0-9a-f]" + r"\A(" + "|".join(root_buses) + r")_[0-9a-f]{2}\.[0-9a-f]" r"(-[0-9a-f]{2}_[0-9a-f]{2}\.[0-9a-f])*\Z" ) return bool(path_re.match(device_id))