Skip to content

Commit f96837f

Browse files
committed
fix: address telemetry and environment quality findings
1 parent 9750814 commit f96837f

3 files changed

Lines changed: 82 additions & 10 deletions

File tree

src/promptfoo/telemetry.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import os
1010
import platform
1111
import sys
12+
import threading
1213
import uuid
1314
from pathlib import Path
1415
from typing import Any
@@ -82,9 +83,10 @@ def _write_global_config(config: dict[str, Any]) -> None:
8283
pass # Silently fail - telemetry should never break the CLI
8384

8485

85-
def _get_user_id() -> str:
86+
def _get_user_id(config: dict[str, Any] | None = None) -> str:
8687
"""Get or create a unique user ID stored in the global config."""
87-
config = _read_global_config()
88+
if config is None:
89+
config = _read_global_config()
8890
user_id = config.get("id")
8991

9092
if not user_id:
@@ -95,9 +97,10 @@ def _get_user_id() -> str:
9597
return user_id
9698

9799

98-
def _get_user_email() -> str | None:
100+
def _get_user_email(config: dict[str, Any] | None = None) -> str | None:
99101
"""Get the user email from the global config if set."""
100-
config = _read_global_config()
102+
if config is None:
103+
config = _read_global_config()
101104
account = config.get("account", {})
102105
return account.get("email") if isinstance(account, dict) else None
103106

@@ -127,8 +130,9 @@ def _ensure_initialized(self) -> None:
127130
return
128131

129132
try:
130-
self._user_id = _get_user_id()
131-
self._email = _get_user_email()
133+
config = _read_global_config()
134+
self._user_id = _get_user_id(config)
135+
self._email = _get_user_email(config)
132136
self._client = Posthog(
133137
project_api_key=_POSTHOG_KEY,
134138
host=_POSTHOG_HOST,
@@ -182,15 +186,20 @@ def shutdown(self) -> None:
182186

183187
# Global singleton instance
184188
_telemetry: _Telemetry | None = None
189+
_telemetry_lock = threading.Lock()
185190

186191

187192
def _get_telemetry() -> _Telemetry:
188193
"""Get the global telemetry instance."""
189194
global _telemetry
190-
if _telemetry is None:
191-
_telemetry = _Telemetry()
192-
atexit.register(_telemetry.shutdown)
193-
return _telemetry
195+
if _telemetry is not None:
196+
return _telemetry
197+
198+
with _telemetry_lock:
199+
if _telemetry is None:
200+
_telemetry = _Telemetry()
201+
atexit.register(_telemetry.shutdown)
202+
return _telemetry
194203

195204

196205
def record_wrapper_used(method: str) -> None:

tests/test_environment.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ def test_read_probe_file_returns_none_when_missing(self, tmp_path: Path) -> None
3131
"""Missing probe files return None."""
3232
assert _read_probe_file(tmp_path / "missing") is None
3333

34+
def test_read_probe_file_returns_content_when_readable(self, tmp_path: Path) -> None:
35+
"""Readable probe files return their text content."""
36+
probe_file = tmp_path / "probe"
37+
probe_file.write_text("value")
38+
39+
assert _read_probe_file(probe_file) == "value"
40+
3441
def test_read_probe_file_returns_none_when_unreadable(self, tmp_path: Path) -> None:
3542
"""Unreadable probe files return None instead of raising."""
3643
probe_file = tmp_path / "probe"
@@ -245,6 +252,7 @@ def test_detect_kubernetes_from_env(self, monkeypatch: pytest.MonkeyPatch) -> No
245252
mock_path.return_value.exists.return_value = False
246253

247254
is_docker, is_k8s = _detect_container()
255+
assert is_docker is False
248256
assert is_k8s is True
249257

250258
def test_detect_container_returns_tuple(self) -> None:

tests/test_telemetry.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
import os
12+
import threading
1213
from pathlib import Path
1314
from unittest import mock
1415

@@ -17,11 +18,13 @@
1718
from promptfoo.telemetry import (
1819
_get_config_dir,
1920
_get_env_bool,
21+
_get_telemetry,
2022
_get_user_email,
2123
_get_user_id,
2224
_is_ci,
2325
_read_global_config,
2426
_Telemetry,
27+
_telemetry_lock,
2528
_write_global_config,
2629
record_wrapper_used,
2730
)
@@ -268,6 +271,23 @@ def test_record_initializes_client(self, tmp_path: Path) -> None:
268271
assert telemetry._client is mock_client
269272
mock_client.capture.assert_called_once()
270273

274+
def test_initialization_reads_global_config_once(self) -> None:
275+
"""Initialization shares one config read across user identity lookups."""
276+
config = {"id": "test-user-id", "account": {"email": "test@example.com"}}
277+
278+
with (
279+
mock.patch.dict(os.environ, {}, clear=True),
280+
mock.patch("promptfoo.telemetry._read_global_config", return_value=config) as mock_read_config,
281+
mock.patch("promptfoo.telemetry.Posthog") as mock_posthog,
282+
):
283+
telemetry = _Telemetry()
284+
telemetry._ensure_initialized()
285+
286+
mock_read_config.assert_called_once_with()
287+
assert telemetry._user_id == "test-user-id"
288+
assert telemetry._email == "test@example.com"
289+
mock_posthog.assert_called_once()
290+
271291
def test_record_enriches_properties(self, tmp_path: Path) -> None:
272292
"""Test record adds enriched properties."""
273293
config_file = tmp_path / "promptfoo.yaml"
@@ -471,3 +491,38 @@ def test_record_wrapper_used_disabled(self, monkeypatch: pytest.MonkeyPatch) ->
471491
with mock.patch("promptfoo.telemetry._telemetry", None):
472492
# Should not raise or make any calls
473493
record_wrapper_used("global")
494+
495+
def test_get_telemetry_guards_singleton_initialization_with_lock(self) -> None:
496+
"""Singleton construction waits on its lock and registers shutdown once."""
497+
started = threading.Event()
498+
finished = threading.Event()
499+
instance = mock.Mock(spec=_Telemetry)
500+
results: list[_Telemetry] = []
501+
502+
def initialize() -> None:
503+
started.set()
504+
results.append(_get_telemetry())
505+
finished.set()
506+
507+
with (
508+
mock.patch("promptfoo.telemetry._telemetry", None),
509+
mock.patch("promptfoo.telemetry._Telemetry", return_value=instance) as mock_telemetry,
510+
mock.patch("promptfoo.telemetry.atexit.register") as mock_register,
511+
):
512+
_telemetry_lock.acquire()
513+
try:
514+
worker = threading.Thread(target=initialize)
515+
worker.start()
516+
assert started.wait(timeout=1)
517+
assert finished.wait(timeout=0.05) is False
518+
mock_telemetry.assert_not_called()
519+
finally:
520+
_telemetry_lock.release()
521+
522+
worker.join(timeout=1)
523+
assert worker.is_alive() is False
524+
525+
assert results == [instance]
526+
assert _get_telemetry() is instance
527+
mock_telemetry.assert_called_once_with()
528+
mock_register.assert_called_once_with(instance.shutdown)

0 commit comments

Comments
 (0)