Skip to content

Commit 79fbc2a

Browse files
authored
fix: correct detection of wrapper shim on Windows/venv (#5)
* fix: improve detection of wrapper shim on Windows * fix: address linting errors * style: apply ruff formatting * refactor: improve code clarity and professionalism
1 parent ced6a8d commit 79fbc2a

File tree

1 file changed

+73
-30
lines changed

1 file changed

+73
-30
lines changed

src/promptfoo/cli.py

Lines changed: 73 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,19 @@ def print_installation_help() -> None:
3939

4040

4141
def _normalize_path(path: str) -> str:
42+
"""Normalize a path for safe comparison."""
4243
return os.path.normcase(os.path.abspath(path))
4344

4445

4546
def _strip_quotes(path: str) -> str:
47+
"""Strip surrounding quotes from a path string."""
4648
if len(path) >= 2 and path[0] == path[-1] and path[0] in ('"', "'"):
4749
return path[1:-1]
4850
return path
4951

5052

5153
def _split_path(path_value: str) -> list[str]:
54+
"""Split a PATH string into a list of directories."""
5255
entries = []
5356
for entry in path_value.split(os.pathsep):
5457
entry = _strip_quotes(entry.strip())
@@ -58,6 +61,7 @@ def _split_path(path_value: str) -> list[str]:
5861

5962

6063
def _resolve_argv0() -> Optional[str]:
64+
"""Resolve the absolute path of the current script (argv[0])."""
6165
if not sys.argv:
6266
return None
6367
argv0 = sys.argv[0]
@@ -72,59 +76,98 @@ def _resolve_argv0() -> Optional[str]:
7276

7377

7478
def _find_windows_promptfoo() -> Optional[str]:
75-
candidates = []
79+
"""
80+
Search for promptfoo in standard Windows installation locations.
81+
Useful when not in PATH.
82+
"""
83+
search_dirs = []
84+
85+
# Check npm config env vars
7686
for key in ("NPM_CONFIG_PREFIX", "npm_config_prefix"):
77-
prefix = os.environ.get(key)
78-
if prefix:
79-
candidates.append(prefix)
80-
appdata = os.environ.get("APPDATA")
81-
if appdata:
82-
candidates.append(os.path.join(appdata, "npm"))
83-
localappdata = os.environ.get("LOCALAPPDATA")
84-
if localappdata:
85-
candidates.append(os.path.join(localappdata, "npm"))
87+
if prefix := os.environ.get(key):
88+
search_dirs.append(prefix)
89+
90+
# Check standard npm folders
91+
if appdata := os.environ.get("APPDATA"):
92+
search_dirs.append(os.path.join(appdata, "npm"))
93+
if localappdata := os.environ.get("LOCALAPPDATA"):
94+
search_dirs.append(os.path.join(localappdata, "npm"))
95+
96+
# Check Program Files
8697
for env_key in ("ProgramFiles", "ProgramFiles(x86)"):
87-
program_files = os.environ.get(env_key)
88-
if program_files:
89-
candidates.append(os.path.join(program_files, "nodejs"))
90-
for base in candidates:
98+
if program_files := os.environ.get(env_key):
99+
search_dirs.append(os.path.join(program_files, "nodejs"))
100+
101+
for base_dir in search_dirs:
91102
for name in ("promptfoo.cmd", "promptfoo.exe"):
92-
candidate = os.path.join(base, name)
103+
candidate = os.path.join(base_dir, name)
93104
if os.path.isfile(candidate):
94105
return candidate
95106
return None
96107

97108

109+
def _is_executing_wrapper(found_path: str) -> bool:
110+
"""
111+
Detect if the found executable is actually this wrapper script.
112+
113+
This handles cases where the wrapper is installed in the same bin/ directory
114+
as the target or if we are inside a virtual environment.
115+
"""
116+
argv0_path = _resolve_argv0()
117+
found_norm = _normalize_path(found_path)
118+
119+
# direct argv0 match
120+
if argv0_path and found_norm == argv0_path:
121+
return True
122+
123+
# venv detection (shim check)
124+
return sys.prefix != sys.base_prefix and os.path.dirname(found_norm) == os.path.dirname(
125+
_normalize_path(sys.executable)
126+
)
127+
128+
129+
def _search_path_excluding(exclude_dir: str) -> Optional[str]:
130+
"""Search PATH for promptfoo, excluding the specified directory."""
131+
path_entries = [entry for entry in _split_path(os.environ.get("PATH", "")) if _normalize_path(entry) != exclude_dir]
132+
if not path_entries:
133+
return None
134+
return shutil.which("promptfoo", path=os.pathsep.join(path_entries))
135+
136+
98137
def _find_external_promptfoo() -> Optional[str]:
99-
promptfoo_path = shutil.which("promptfoo")
100-
if not promptfoo_path:
138+
"""Find the external promptfoo executable, avoiding the wrapper itself."""
139+
# 1. First naive search
140+
candidate = shutil.which("promptfoo")
141+
142+
# 2. If not found, try explicit Windows paths
143+
if not candidate:
101144
if os.name == "nt":
102145
return _find_windows_promptfoo()
103146
return None
104-
argv0_path = _resolve_argv0()
105-
if argv0_path and _normalize_path(promptfoo_path) == argv0_path:
106-
wrapper_dir = _normalize_path(os.path.dirname(promptfoo_path))
107-
path_entries = [
108-
entry for entry in _split_path(os.environ.get("PATH", "")) if _normalize_path(entry) != wrapper_dir
109-
]
110-
if path_entries:
111-
candidate = shutil.which("promptfoo", path=os.pathsep.join(path_entries))
112-
if candidate:
113-
return candidate
114-
if os.name == "nt":
147+
148+
# 3. If found, check if it's us (the wrapper)
149+
if _is_executing_wrapper(candidate):
150+
wrapper_dir = _normalize_path(os.path.dirname(candidate))
151+
# Search again excluding our directory
152+
candidate = _search_path_excluding(wrapper_dir)
153+
154+
# If still not found, try Windows fallback
155+
if not candidate and os.name == "nt":
115156
return _find_windows_promptfoo()
116-
return None
117-
return promptfoo_path
157+
158+
return candidate
118159

119160

120161
def _requires_shell(executable: str) -> bool:
162+
"""Check if the executable requires a shell to run (Windows only)."""
121163
if os.name != "nt":
122164
return False
123165
_, ext = os.path.splitext(executable)
124166
return ext.lower() in _WINDOWS_SHELL_EXTENSIONS
125167

126168

127169
def _run_command(cmd: list[str], env: Optional[dict[str, str]] = None) -> subprocess.CompletedProcess:
170+
"""Execute a command, handling shell requirements on Windows."""
128171
if _requires_shell(cmd[0]):
129172
return subprocess.run(subprocess.list2cmdline(cmd), shell=True, env=env)
130173
return subprocess.run(cmd, env=env)

0 commit comments

Comments
 (0)