Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 80 additions & 82 deletions src/promptfoo/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,26 @@ class Environment:
has_sudo: bool = False # Best guess if user has sudo access


def _read_probe_file(path: Path) -> Optional[str]:
"""
Read an optional environment probe file.

Returns:
File contents, or None when the probe file does not exist or cannot be read.
"""
if not path.exists():
return None

try:
with open(path) as f:
return f.read()
except OSError:
# Environment detection is best-effort. Proc/sys metadata files can be
# unreadable or disappear between exists() and open(), so treat that as
# "signal unavailable" and continue with fallback probes.
return None


def _detect_linux_distro() -> tuple[Optional[str], Optional[str]]:
"""
Detect Linux distribution and version.
Expand All @@ -48,47 +68,46 @@ def _detect_linux_distro() -> tuple[Optional[str], Optional[str]]:

# Try /etc/os-release first, then /usr/lib/os-release (per freedesktop spec)
for os_release_path in [Path("/etc/os-release"), Path("/usr/lib/os-release")]:
if os_release_path.exists():
try:
with open(os_release_path) as f:
os_release = {}
for line in f:
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" in line:
key, _, value = line.partition("=")
# Remove quotes
value = value.strip('"').strip("'")
os_release[key] = value

distro_id = os_release.get("ID", "").lower()
version = os_release.get("VERSION_ID", "")
id_like = os_release.get("ID_LIKE", "").lower().split()

# Normalize distro IDs
if distro_id in known_base_distros:
return distro_id, version
elif distro_id in rhel_family:
# Oracle Linux (ol), Amazon Linux (amzn)
return "rhel", version
elif distro_id in suse_family:
return "suse", version

# Check ID_LIKE for derivative distributions (e.g., Pop!_OS, Raspbian, Mint)
if id_like:
for parent in id_like:
if parent in known_base_distros:
return parent, version
elif parent in rhel_family:
return "rhel", version
elif parent in suse_family:
return "suse", version

# Return the raw distro_id if we couldn't normalize it
return distro_id, version
except OSError:
pass
os_release_content = _read_probe_file(os_release_path)
if os_release_content is None:
continue

os_release = {}
for line in os_release_content.splitlines():
line = line.strip()
if not line or line.startswith("#"):
continue
if "=" in line:
key, _, value = line.partition("=")
# Remove quotes
value = value.strip('"').strip("'")
os_release[key] = value

distro_id = os_release.get("ID", "").lower()
version = os_release.get("VERSION_ID", "")
id_like = os_release.get("ID_LIKE", "").lower().split()

# Normalize distro IDs
if distro_id in known_base_distros:
return distro_id, version
elif distro_id in rhel_family:
# Oracle Linux (ol), Amazon Linux (amzn)
return "rhel", version
elif distro_id in suse_family:
return "suse", version

# Check ID_LIKE for derivative distributions (e.g., Pop!_OS, Raspbian, Mint)
if id_like:
for parent in id_like:
if parent in known_base_distros:
return parent, version
elif parent in rhel_family:
return "rhel", version
elif parent in suse_family:
return "suse", version

# Return the raw distro_id if we couldn't normalize it
return distro_id, version

# Fallback: check for specific files
if Path("/etc/debian_version").exists():
Expand All @@ -112,44 +131,31 @@ def _detect_cloud_provider() -> Optional[str]:
"""
# AWS detection
# Check for EC2 metadata
if Path("/sys/hypervisor/uuid").exists():
try:
with open("/sys/hypervisor/uuid") as f:
uuid = f.read().strip()
if uuid.startswith("ec2") or uuid.startswith("EC2"):
return "aws"
except OSError:
pass
uuid = _read_probe_file(Path("/sys/hypervisor/uuid"))
if uuid and uuid.strip().lower().startswith("ec2"):
return "aws"

# Check AWS environment variables
if os.getenv("AWS_EXECUTION_ENV") or os.getenv("AWS_REGION"):
return "aws"

# GCP detection
# Check for GCP metadata
if Path("/sys/class/dmi/id/product_name").exists():
try:
with open("/sys/class/dmi/id/product_name") as f:
product = f.read().strip()
if "Google" in product or "GCE" in product:
return "gcp"
except OSError:
pass
product = _read_probe_file(Path("/sys/class/dmi/id/product_name"))
if product:
product = product.strip()
if "Google" in product or "GCE" in product:
return "gcp"

# Check GCP environment variables
if os.getenv("GOOGLE_CLOUD_PROJECT") or os.getenv("GCP_PROJECT"):
return "gcp"

# Azure detection
if Path("/sys/class/dmi/id/sys_vendor").exists():
try:
with open("/sys/class/dmi/id/sys_vendor") as f:
vendor = f.read().strip()
# Could be Azure or Hyper-V, check for Azure-specific
if "Microsoft Corporation" in vendor and Path("/var/lib/waagent").exists():
return "azure"
except OSError:
pass
vendor = _read_probe_file(Path("/sys/class/dmi/id/sys_vendor"))
# Could be Azure or Hyper-V, check for Azure-specific
if vendor and "Microsoft Corporation" in vendor.strip() and Path("/var/lib/waagent").exists():
return "azure"

# Check Azure environment variables
if os.getenv("AZURE_SUBSCRIPTION_ID") or os.getenv("WEBSITE_INSTANCE_ID"):
Expand All @@ -173,14 +179,9 @@ def _detect_container() -> tuple[bool, bool]:
is_docker = True

# Also check cgroup
if Path("/proc/1/cgroup").exists():
try:
with open("/proc/1/cgroup") as f:
cgroup_content = f.read()
if "docker" in cgroup_content or "containerd" in cgroup_content:
is_docker = True
except OSError:
pass
cgroup_content = _read_probe_file(Path("/proc/1/cgroup"))
if cgroup_content and ("docker" in cgroup_content or "containerd" in cgroup_content):
is_docker = True

# Kubernetes detection
if os.getenv("KUBERNETES_SERVICE_HOST"):
Expand All @@ -201,14 +202,11 @@ def _detect_wsl() -> bool:
return True

# Check /proc/version for Microsoft/WSL signatures
if Path("/proc/version").exists():
try:
with open("/proc/version") as f:
version_info = f.read().lower()
if "microsoft" in version_info or "wsl" in version_info:
return True
except OSError:
pass
version_info = _read_probe_file(Path("/proc/version"))
if version_info:
version_info = version_info.lower()
if "microsoft" in version_info or "wsl" in version_info:
return True

# Check for Windows filesystem mounts (WSL mounts Windows drives at /mnt/)
# This is less reliable but can catch WSL 1
Expand Down
104 changes: 98 additions & 6 deletions tests/test_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,29 @@
_detect_container,
_detect_linux_distro,
_detect_python_env,
_detect_wsl,
_has_sudo_access,
_read_probe_file,
detect_environment,
)


class TestProbeFileReads:
"""Test best-effort probe file reads."""

def test_read_probe_file_returns_none_when_missing(self, tmp_path: Path) -> None:
"""Missing probe files return None."""
assert _read_probe_file(tmp_path / "missing") is None

def test_read_probe_file_returns_none_when_unreadable(self, tmp_path: Path) -> None:
"""Unreadable probe files return None instead of raising."""
probe_file = tmp_path / "probe"
probe_file.write_text("value")

with mock.patch("builtins.open", side_effect=OSError("permission denied")):
assert _read_probe_file(probe_file) is None


class TestLinuxDistroDetection:
"""Test Linux distribution detection."""

Expand Down Expand Up @@ -128,6 +146,40 @@ def path_constructor(path_str: str) -> mock.Mock:
assert distro == "ubuntu"
assert version == "22.04"

def test_detect_linux_distro_skips_unreadable_os_release(self) -> None:
"""Unreadable /etc/os-release falls back to /usr/lib/os-release."""
etc_path = mock.Mock()
etc_path.exists.return_value = True

usr_path = mock.Mock()
usr_path.exists.return_value = True

def path_constructor(path_str: str) -> mock.Mock:
if path_str == "/etc/os-release":
return etc_path
elif path_str == "/usr/lib/os-release":
return usr_path
fallback_path = mock.Mock()
fallback_path.exists.return_value = False
return fallback_path

usr_open = mock.mock_open(read_data='ID=ubuntu\nVERSION_ID="22.04"')

def open_side_effect(path: mock.Mock) -> mock.MagicMock:
if path is etc_path:
raise OSError("permission denied")
if path is usr_path:
return usr_open()
raise AssertionError(f"unexpected probe path: {path!r}")

with (
mock.patch("promptfoo.environment.Path", side_effect=path_constructor),
mock.patch("builtins.open", side_effect=open_side_effect),
):
distro, version = _detect_linux_distro()
assert distro == "ubuntu"
assert version == "22.04"


class TestCloudProviderDetection:
"""Test cloud provider detection."""
Expand Down Expand Up @@ -184,6 +236,20 @@ def test_no_cloud_provider_detected(self) -> None:
provider = _detect_cloud_provider()
assert provider is None

def test_detect_cloud_provider_ignores_unreadable_probe_files(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Unreadable cloud metadata files fall back to environment variables."""
monkeypatch.setenv("GOOGLE_CLOUD_PROJECT", "my-project")

path_mock = mock.Mock()
path_mock.exists.return_value = True

with (
mock.patch("promptfoo.environment.Path", return_value=path_mock),
mock.patch("builtins.open", side_effect=OSError("permission denied")),
):
provider = _detect_cloud_provider()
assert provider == "gcp"


class TestContainerDetection:
"""Test container detection."""
Expand All @@ -204,6 +270,23 @@ def test_detect_container_returns_tuple(self) -> None:
assert isinstance(is_docker, bool)
assert isinstance(is_k8s, bool)

def test_detect_container_ignores_unreadable_cgroup(self) -> None:
"""Unreadable cgroup metadata does not raise."""

def path_constructor(path_str: str) -> mock.Mock:
path_mock = mock.Mock()
path_mock.exists.return_value = path_str == "/proc/1/cgroup"
return path_mock

with (
mock.patch("promptfoo.environment.Path", side_effect=path_constructor),
mock.patch("builtins.open", side_effect=OSError("permission denied")),
mock.patch.dict(os.environ, {}, clear=True),
):
is_docker, is_k8s = _detect_container()
assert is_docker is False
assert is_k8s is False


class TestWSLDetection:
"""Test WSL detection."""
Expand All @@ -212,28 +295,37 @@ def test_detect_wsl_from_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Detect WSL from WSL_DISTRO_NAME environment variable."""
monkeypatch.setenv("WSL_DISTRO_NAME", "Ubuntu")

from promptfoo.environment import _detect_wsl

assert _detect_wsl() is True

def test_detect_wsl_from_interop_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Detect WSL from WSL_INTEROP environment variable."""
monkeypatch.setenv("WSL_INTEROP", "/run/WSL/123_interop")

from promptfoo.environment import _detect_wsl

assert _detect_wsl() is True

def test_no_wsl_detected(self) -> None:
"""Return False when not in WSL."""
with mock.patch.dict(os.environ, {}, clear=True):
from promptfoo.environment import _detect_wsl

# This will return False unless we're actually in WSL
# Just verify it returns a boolean
result = _detect_wsl()
assert isinstance(result, bool)

def test_detect_wsl_ignores_unreadable_proc_version(self) -> None:
"""Unreadable /proc/version does not raise."""

def path_constructor(path_str: str) -> mock.Mock:
path_mock = mock.Mock()
path_mock.exists.return_value = path_str == "/proc/version"
return path_mock

with (
mock.patch("promptfoo.environment.Path", side_effect=path_constructor),
mock.patch("builtins.open", side_effect=OSError("permission denied")),
mock.patch.dict(os.environ, {}, clear=True),
):
assert _detect_wsl() is False


class TestCIDetection:
"""Test CI/CD platform detection."""
Expand Down
Loading