diff --git a/src/promptfoo/environment.py b/src/promptfoo/environment.py index 5dd8e91..3159ee3 100644 --- a/src/promptfoo/environment.py +++ b/src/promptfoo/environment.py @@ -33,6 +33,26 @@ class Environment: has_sudo: bool = False # Best guess if user has sudo access +def _read_probe_file(path: Path) -> Optional[str]: + """ + Read an optional environment probe file. + + Returns: + File contents, or None when the probe file does not exist or cannot be read. + """ + if not path.exists(): + return None + + try: + with open(path) as f: + return f.read() + except OSError: + # Environment detection is best-effort. Proc/sys metadata files can be + # unreadable or disappear between exists() and open(), so treat that as + # "signal unavailable" and continue with fallback probes. + return None + + def _detect_linux_distro() -> tuple[Optional[str], Optional[str]]: """ Detect Linux distribution and version. @@ -48,47 +68,46 @@ def _detect_linux_distro() -> tuple[Optional[str], Optional[str]]: # Try /etc/os-release first, then /usr/lib/os-release (per freedesktop spec) for os_release_path in [Path("/etc/os-release"), Path("/usr/lib/os-release")]: - if os_release_path.exists(): - try: - with open(os_release_path) as f: - os_release = {} - for line in f: - line = line.strip() - if not line or line.startswith("#"): - continue - if "=" in line: - key, _, value = line.partition("=") - # Remove quotes - value = value.strip('"').strip("'") - os_release[key] = value - - distro_id = os_release.get("ID", "").lower() - version = os_release.get("VERSION_ID", "") - id_like = os_release.get("ID_LIKE", "").lower().split() - - # Normalize distro IDs - if distro_id in known_base_distros: - return distro_id, version - elif distro_id in rhel_family: - # Oracle Linux (ol), Amazon Linux (amzn) - return "rhel", version - elif distro_id in suse_family: - return "suse", version - - # Check ID_LIKE for derivative distributions (e.g., Pop!_OS, Raspbian, Mint) - if id_like: - for parent in id_like: - if parent in known_base_distros: - return parent, version - elif parent in rhel_family: - return "rhel", version - elif parent in suse_family: - return "suse", version - - # Return the raw distro_id if we couldn't normalize it - return distro_id, version - except OSError: - pass + os_release_content = _read_probe_file(os_release_path) + if os_release_content is None: + continue + + os_release = {} + for line in os_release_content.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" in line: + key, _, value = line.partition("=") + # Remove quotes + value = value.strip('"').strip("'") + os_release[key] = value + + distro_id = os_release.get("ID", "").lower() + version = os_release.get("VERSION_ID", "") + id_like = os_release.get("ID_LIKE", "").lower().split() + + # Normalize distro IDs + if distro_id in known_base_distros: + return distro_id, version + elif distro_id in rhel_family: + # Oracle Linux (ol), Amazon Linux (amzn) + return "rhel", version + elif distro_id in suse_family: + return "suse", version + + # Check ID_LIKE for derivative distributions (e.g., Pop!_OS, Raspbian, Mint) + if id_like: + for parent in id_like: + if parent in known_base_distros: + return parent, version + elif parent in rhel_family: + return "rhel", version + elif parent in suse_family: + return "suse", version + + # Return the raw distro_id if we couldn't normalize it + return distro_id, version # Fallback: check for specific files if Path("/etc/debian_version").exists(): @@ -112,14 +131,9 @@ def _detect_cloud_provider() -> Optional[str]: """ # AWS detection # Check for EC2 metadata - if Path("/sys/hypervisor/uuid").exists(): - try: - with open("/sys/hypervisor/uuid") as f: - uuid = f.read().strip() - if uuid.startswith("ec2") or uuid.startswith("EC2"): - return "aws" - except OSError: - pass + uuid = _read_probe_file(Path("/sys/hypervisor/uuid")) + if uuid and uuid.strip().lower().startswith("ec2"): + return "aws" # Check AWS environment variables if os.getenv("AWS_EXECUTION_ENV") or os.getenv("AWS_REGION"): @@ -127,29 +141,21 @@ def _detect_cloud_provider() -> Optional[str]: # GCP detection # Check for GCP metadata - if Path("/sys/class/dmi/id/product_name").exists(): - try: - with open("/sys/class/dmi/id/product_name") as f: - product = f.read().strip() - if "Google" in product or "GCE" in product: - return "gcp" - except OSError: - pass + product = _read_probe_file(Path("/sys/class/dmi/id/product_name")) + if product: + product = product.strip() + if "Google" in product or "GCE" in product: + return "gcp" # Check GCP environment variables if os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GCP_PROJECT"): return "gcp" # Azure detection - if Path("/sys/class/dmi/id/sys_vendor").exists(): - try: - with open("/sys/class/dmi/id/sys_vendor") as f: - vendor = f.read().strip() - # Could be Azure or Hyper-V, check for Azure-specific - if "Microsoft Corporation" in vendor and Path("/var/lib/waagent").exists(): - return "azure" - except OSError: - pass + vendor = _read_probe_file(Path("/sys/class/dmi/id/sys_vendor")) + # Could be Azure or Hyper-V, check for Azure-specific + if vendor and "Microsoft Corporation" in vendor.strip() and Path("/var/lib/waagent").exists(): + return "azure" # Check Azure environment variables if os.getenv("AZURE_SUBSCRIPTION_ID") or os.getenv("WEBSITE_INSTANCE_ID"): @@ -173,14 +179,9 @@ def _detect_container() -> tuple[bool, bool]: is_docker = True # Also check cgroup - if Path("/proc/1/cgroup").exists(): - try: - with open("/proc/1/cgroup") as f: - cgroup_content = f.read() - if "docker" in cgroup_content or "containerd" in cgroup_content: - is_docker = True - except OSError: - pass + cgroup_content = _read_probe_file(Path("/proc/1/cgroup")) + if cgroup_content and ("docker" in cgroup_content or "containerd" in cgroup_content): + is_docker = True # Kubernetes detection if os.getenv("KUBERNETES_SERVICE_HOST"): @@ -201,14 +202,11 @@ def _detect_wsl() -> bool: return True # Check /proc/version for Microsoft/WSL signatures - if Path("/proc/version").exists(): - try: - with open("/proc/version") as f: - version_info = f.read().lower() - if "microsoft" in version_info or "wsl" in version_info: - return True - except OSError: - pass + version_info = _read_probe_file(Path("/proc/version")) + if version_info: + version_info = version_info.lower() + if "microsoft" in version_info or "wsl" in version_info: + return True # Check for Windows filesystem mounts (WSL mounts Windows drives at /mnt/) # This is less reliable but can catch WSL 1 diff --git a/tests/test_environment.py b/tests/test_environment.py index 95207a7..7749bf6 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -17,11 +17,29 @@ _detect_container, _detect_linux_distro, _detect_python_env, + _detect_wsl, _has_sudo_access, + _read_probe_file, detect_environment, ) +class TestProbeFileReads: + """Test best-effort probe file reads.""" + + def test_read_probe_file_returns_none_when_missing(self, tmp_path: Path) -> None: + """Missing probe files return None.""" + assert _read_probe_file(tmp_path / "missing") is None + + def test_read_probe_file_returns_none_when_unreadable(self, tmp_path: Path) -> None: + """Unreadable probe files return None instead of raising.""" + probe_file = tmp_path / "probe" + probe_file.write_text("value") + + with mock.patch("builtins.open", side_effect=OSError("permission denied")): + assert _read_probe_file(probe_file) is None + + class TestLinuxDistroDetection: """Test Linux distribution detection.""" @@ -128,6 +146,40 @@ def path_constructor(path_str: str) -> mock.Mock: assert distro == "ubuntu" assert version == "22.04" + def test_detect_linux_distro_skips_unreadable_os_release(self) -> None: + """Unreadable /etc/os-release falls back to /usr/lib/os-release.""" + etc_path = mock.Mock() + etc_path.exists.return_value = True + + usr_path = mock.Mock() + usr_path.exists.return_value = True + + def path_constructor(path_str: str) -> mock.Mock: + if path_str == "/etc/os-release": + return etc_path + elif path_str == "/usr/lib/os-release": + return usr_path + fallback_path = mock.Mock() + fallback_path.exists.return_value = False + return fallback_path + + usr_open = mock.mock_open(read_data='ID=ubuntu\nVERSION_ID="22.04"') + + def open_side_effect(path: mock.Mock) -> mock.MagicMock: + if path is etc_path: + raise OSError("permission denied") + if path is usr_path: + return usr_open() + raise AssertionError(f"unexpected probe path: {path!r}") + + with ( + mock.patch("promptfoo.environment.Path", side_effect=path_constructor), + mock.patch("builtins.open", side_effect=open_side_effect), + ): + distro, version = _detect_linux_distro() + assert distro == "ubuntu" + assert version == "22.04" + class TestCloudProviderDetection: """Test cloud provider detection.""" @@ -184,6 +236,20 @@ def test_no_cloud_provider_detected(self) -> None: provider = _detect_cloud_provider() assert provider is None + def test_detect_cloud_provider_ignores_unreadable_probe_files(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Unreadable cloud metadata files fall back to environment variables.""" + monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "my-project") + + path_mock = mock.Mock() + path_mock.exists.return_value = True + + with ( + mock.patch("promptfoo.environment.Path", return_value=path_mock), + mock.patch("builtins.open", side_effect=OSError("permission denied")), + ): + provider = _detect_cloud_provider() + assert provider == "gcp" + class TestContainerDetection: """Test container detection.""" @@ -204,6 +270,23 @@ def test_detect_container_returns_tuple(self) -> None: assert isinstance(is_docker, bool) assert isinstance(is_k8s, bool) + def test_detect_container_ignores_unreadable_cgroup(self) -> None: + """Unreadable cgroup metadata does not raise.""" + + def path_constructor(path_str: str) -> mock.Mock: + path_mock = mock.Mock() + path_mock.exists.return_value = path_str == "/proc/1/cgroup" + return path_mock + + with ( + mock.patch("promptfoo.environment.Path", side_effect=path_constructor), + mock.patch("builtins.open", side_effect=OSError("permission denied")), + mock.patch.dict(os.environ, {}, clear=True), + ): + is_docker, is_k8s = _detect_container() + assert is_docker is False + assert is_k8s is False + class TestWSLDetection: """Test WSL detection.""" @@ -212,28 +295,37 @@ def test_detect_wsl_from_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None: """Detect WSL from WSL_DISTRO_NAME environment variable.""" monkeypatch.setenv("WSL_DISTRO_NAME", "Ubuntu") - from promptfoo.environment import _detect_wsl - assert _detect_wsl() is True def test_detect_wsl_from_interop_env(self, monkeypatch: pytest.MonkeyPatch) -> None: """Detect WSL from WSL_INTEROP environment variable.""" monkeypatch.setenv("WSL_INTEROP", "/run/WSL/123_interop") - from promptfoo.environment import _detect_wsl - assert _detect_wsl() is True def test_no_wsl_detected(self) -> None: """Return False when not in WSL.""" with mock.patch.dict(os.environ, {}, clear=True): - from promptfoo.environment import _detect_wsl - # This will return False unless we're actually in WSL # Just verify it returns a boolean result = _detect_wsl() assert isinstance(result, bool) + def test_detect_wsl_ignores_unreadable_proc_version(self) -> None: + """Unreadable /proc/version does not raise.""" + + def path_constructor(path_str: str) -> mock.Mock: + path_mock = mock.Mock() + path_mock.exists.return_value = path_str == "/proc/version" + return path_mock + + with ( + mock.patch("promptfoo.environment.Path", side_effect=path_constructor), + mock.patch("builtins.open", side_effect=OSError("permission denied")), + mock.patch.dict(os.environ, {}, clear=True), + ): + assert _detect_wsl() is False + class TestCIDetection: """Test CI/CD platform detection."""