Skip to content

Commit 5e97386

Browse files
abrichr and claude committed
fix: don't take a screenshot at import time (breaks headless imports)
recorder.py computed monitor dimensions via utils.take_screenshot() at module scope, so `import openadapt_capture` crashed in any headless environment whose display reported a zero-size region. This took down `openadapt version` and `openadapt doctor` (found via the new CLI smoke tests in OpenAdaptAI/OpenAdapt). Move the computation into the video-setup function that is its only consumer. Adds tests/test_headless_import.py: a deterministic AST guard that no package module calls a display API (take_screenshot/get_monitor_dims/grab) at import scope. A subprocess import test is unreliable (only reproduces on a genuinely headless display); this fails regardless of environment. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
1 parent cd7ed78 commit 5e97386

2 files changed

Lines changed: 100 additions & 63 deletions

File tree

openadapt_capture/recorder.py

Lines changed: 42 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,8 @@
4545
soundfile = None
4646
websockets = None
4747

48-
def set_browser_mode(
49-
mode: str, websocket: "websockets.sync.server.ServerConnection"
50-
) -> None:
48+
49+
def set_browser_mode(mode: str, websocket: "websockets.sync.server.ServerConnection") -> None:
5150
"""Send a message to the browser extension to set the mode."""
5251
logger.info(f"{type(websocket)=}")
5352
VALID_MODES = ("idle", "record", "replay")
@@ -142,9 +141,6 @@ def __bool__(self):
142141
stop_sequence_detected = False
143142
ws_server_instance = None
144143

145-
# TODO XXX replace with utils.get_monitor_dims() once fixed
146-
monitor_width, monitor_height = utils.take_screenshot().size
147-
148144

149145
def collect_stats(performance_snapshots: list[tracemalloc.Snapshot]) -> None:
150146
"""Collects and appends performance snapshots using tracemalloc.
@@ -533,7 +529,9 @@ def write_events(
533529

534530

535531
def video_pre_callback(
536-
db: crud.SaSession, recording: Recording, video_dir: str = None,
532+
db: crud.SaSession,
533+
recording: Recording,
534+
video_dir: str = None,
537535
) -> dict[str, Any]:
538536
"""Function to call before main loop.
539537
@@ -546,8 +544,10 @@ def video_pre_callback(
546544
dict[str, Any]: The updated state.
547545
"""
548546
video_file_path = video.get_video_file_path(recording.timestamp, video_dir)
549-
video_container, video_stream, video_start_timestamp = (
550-
video.initialize_video_writer(video_file_path, monitor_width, monitor_height)
547+
# TODO XXX replace with utils.get_monitor_dims() once fixed
548+
monitor_width, monitor_height = utils.take_screenshot().size
549+
video_container, video_stream, video_start_timestamp = video.initialize_video_writer(
550+
video_file_path, monitor_width, monitor_height
551551
)
552552
crud.update_video_start_time(db, recording, video_start_timestamp)
553553
return {
@@ -643,9 +643,7 @@ def write_video_event(
643643
}
644644

645645

646-
def trigger_action_event(
647-
event_q: queue.Queue, action_event_args: dict[str, Any]
648-
) -> None:
646+
def trigger_action_event(event_q: queue.Queue, action_event_args: dict[str, Any]) -> None:
649647
"""Triggers an action event and adds it to the event queue.
650648
651649
Args:
@@ -778,9 +776,7 @@ def handle_key(
778776
"char",
779777
"vk",
780778
]
781-
attrs = {
782-
f"key_{attr_name}": getattr(key, attr_name, None) for attr_name in attr_names
783-
}
779+
attrs = {f"key_{attr_name}": getattr(key, attr_name, None) for attr_name in attr_names}
784780
logger.debug(f"{attrs=}")
785781
canonical_attrs = {
786782
f"canonical_key_{attr_name}": getattr(canonical_key, attr_name, None)
@@ -1216,9 +1212,7 @@ def audio_callback(
12161212
audio_frames.append(indata.copy())
12171213

12181214
# open InputStream and start recording while ActionEvents are recorded
1219-
audio_stream = sounddevice.InputStream(
1220-
callback=audio_callback, samplerate=16000, channels=1
1221-
)
1215+
audio_stream = sounddevice.InputStream(callback=audio_callback, samplerate=16000, channels=1)
12221216
logger.info("Audio recording started.")
12231217
start_timestamp = utils.get_timestamp()
12241218
audio_stream.start()
@@ -1239,6 +1233,7 @@ def audio_callback(
12391233
# Convert audio to text using OpenAI's Whisper
12401234
logger.info("Transcribing audio...")
12411235
import whisper
1236+
12421237
model = whisper.load_model("base")
12431238
result_info = model.transcribe(converted_audio, word_timestamps=True, fp16=False)
12441239
logger.info(f"The narrated text is: {result_info['text']}")
@@ -1251,21 +1246,15 @@ def audio_callback(
12511246
word_list = result_info["segments"][0]["words"]
12521247

12531248
# compress and convert to bytes to save to database
1254-
logger.info(
1255-
"Size of uncompressed audio data: {} bytes".format(converted_audio.nbytes)
1256-
)
1249+
logger.info("Size of uncompressed audio data: {} bytes".format(converted_audio.nbytes))
12571250
# Create an in-memory file-like object
12581251
file_obj = io.BytesIO()
12591252
# Write the audio data using lossless compression
1260-
soundfile.write(
1261-
file_obj, converted_audio, int(audio_stream.samplerate), format="FLAC"
1262-
)
1253+
soundfile.write(file_obj, converted_audio, int(audio_stream.samplerate), format="FLAC")
12631254
# Get the compressed audio data as bytes
12641255
compressed_audio_bytes = file_obj.getvalue()
12651256

1266-
logger.info(
1267-
"Size of compressed audio data: {} bytes".format(len(compressed_audio_bytes))
1268-
)
1257+
logger.info("Size of compressed audio data: {} bytes".format(len(compressed_audio_bytes)))
12691258

12701259
file_obj.close()
12711260

@@ -1462,9 +1451,7 @@ def record(
14621451
event_q,
14631452
terminate_processing,
14641453
recording,
1465-
task_started_events.setdefault(
1466-
"window_event_reader", threading.Event()
1467-
),
1454+
task_started_events.setdefault("window_event_reader", threading.Event()),
14681455
),
14691456
)
14701457
window_event_reader.start()
@@ -1477,9 +1464,7 @@ def record(
14771464
event_q,
14781465
terminate_processing,
14791466
recording,
1480-
task_started_events.setdefault(
1481-
"browser_event_reader", threading.Event()
1482-
),
1467+
task_started_events.setdefault("browser_event_reader", threading.Event()),
14831468
),
14841469
)
14851470
browser_event_reader.start()
@@ -1567,9 +1552,7 @@ def record(
15671552
recording,
15681553
db_path,
15691554
terminate_processing,
1570-
task_started_events.setdefault(
1571-
"screen_event_writer", multiprocessing.Event()
1572-
),
1555+
task_started_events.setdefault("screen_event_writer", multiprocessing.Event()),
15731556
),
15741557
)
15751558
screen_event_writer.start()
@@ -1587,9 +1570,7 @@ def record(
15871570
recording,
15881571
db_path,
15891572
terminate_processing,
1590-
task_started_events.setdefault(
1591-
"browser_event_writer", multiprocessing.Event()
1592-
),
1573+
task_started_events.setdefault("browser_event_writer", multiprocessing.Event()),
15931574
),
15941575
)
15951576
browser_event_writer.start()
@@ -1606,9 +1587,7 @@ def record(
16061587
recording,
16071588
db_path,
16081589
terminate_processing,
1609-
task_started_events.setdefault(
1610-
"action_event_writer", multiprocessing.Event()
1611-
),
1590+
task_started_events.setdefault("action_event_writer", multiprocessing.Event()),
16121591
),
16131592
)
16141593
action_event_writer.start()
@@ -1626,9 +1605,7 @@ def record(
16261605
recording,
16271606
db_path,
16281607
terminate_processing,
1629-
task_started_events.setdefault(
1630-
"window_event_writer", multiprocessing.Event()
1631-
),
1608+
task_started_events.setdefault("window_event_writer", multiprocessing.Event()),
16321609
),
16331610
)
16341611
window_event_writer.start()
@@ -1661,9 +1638,7 @@ def record(
16611638
recording,
16621639
db_path,
16631640
terminate_processing,
1664-
task_started_events.setdefault(
1665-
"audio_event_writer", multiprocessing.Event()
1666-
),
1641+
task_started_events.setdefault("audio_event_writer", multiprocessing.Event()),
16671642
),
16681643
)
16691644
audio_recorder.start()
@@ -1677,9 +1652,7 @@ def record(
16771652
recording,
16781653
db_path,
16791654
terminate_perf_event,
1680-
task_started_events.setdefault(
1681-
"perf_stats_writer", multiprocessing.Event()
1682-
),
1655+
task_started_events.setdefault("perf_stats_writer", multiprocessing.Event()),
16831656
),
16841657
)
16851658
perf_stats_writer.start()
@@ -1715,9 +1688,7 @@ def record(
17151688
started_tasks = sum(event.is_set() for event in task_started_events.values())
17161689
if started_tasks >= expected_starts:
17171690
break
1718-
waiting_for = [
1719-
task for task, event in task_started_events.items() if not event.is_set()
1720-
]
1691+
waiting_for = [task for task, event in task_started_events.items() if not event.is_set()]
17211692
logger.info(f"Waiting for tasks to start: {waiting_for}")
17221693
logger.info(f"Started tasks: {started_tasks}/{expected_starts}")
17231694
time.sleep(1) # Sleep to reduce busy waiting
@@ -1780,7 +1751,9 @@ def join_tasks(task_names: list[str]) -> None:
17801751
if config.PLOT_PERFORMANCE:
17811752
session = get_session_for_path(db_path)
17821753
plotting.plot_performance(
1783-
session, recording, save_dir=capture_dir,
1754+
session,
1755+
recording,
1756+
save_dir=capture_dir,
17841757
)
17851758

17861759
logger.info(f"Saved {recording_timestamp=}")
@@ -1824,6 +1797,7 @@ def join_tasks(task_names: list[str]) -> None:
18241797
_profile_path = os.path.join(capture_dir, "profiling.json")
18251798
try:
18261799
import json as _json
1800+
18271801
with open(_profile_path, "w") as _f:
18281802
_json.dump(_profile_data, _f, indent=2)
18291803
logger.info(f"Profiling saved to {_profile_path}")
@@ -1838,13 +1812,17 @@ def join_tasks(task_names: list[str]) -> None:
18381812
print(f" {k}: {v} events ({rate:.1f}/s)")
18391813
if _screen_timing:
18401814
st = _profile_data["screen_timing"]
1841-
print(f" screenshot: avg={st['screenshot_avg_ms']}ms "
1842-
f"max={st['screenshot_max_ms']}ms "
1843-
f"min={st['screenshot_min_ms']}ms")
1844-
print(f"Config: WINDOW_DATA={config.RECORD_WINDOW_DATA} "
1845-
f"VIDEO={config.RECORD_VIDEO} "
1846-
f"PLOT_PERF={config.PLOT_PERFORMANCE} "
1847-
f"FPS={config.SCREEN_CAPTURE_FPS}")
1815+
print(
1816+
f" screenshot: avg={st['screenshot_avg_ms']}ms "
1817+
f"max={st['screenshot_max_ms']}ms "
1818+
f"min={st['screenshot_min_ms']}ms"
1819+
)
1820+
print(
1821+
f"Config: WINDOW_DATA={config.RECORD_WINDOW_DATA} "
1822+
f"VIDEO={config.RECORD_VIDEO} "
1823+
f"PLOT_PERF={config.PLOT_PERFORMANCE} "
1824+
f"FPS={config.SCREEN_CAPTURE_FPS}"
1825+
)
18481826
print("=========================\n")
18491827

18501828
# Auto-send profiling via wormhole if requested
@@ -1977,7 +1955,8 @@ def _run_record(self) -> None:
19771955
def __enter__(self) -> "Recorder":
19781956
# Start status drain thread
19791957
self._status_thread = threading.Thread(
1980-
target=self._drain_status_pipe, daemon=True,
1958+
target=self._drain_status_pipe,
1959+
daemon=True,
19811960
)
19821961
self._status_thread.start()
19831962

tests/test_headless_import.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""Importing openadapt_capture must not call display APIs at module scope.
2+
3+
recorder.py used to compute monitor dimensions via a screenshot at
4+
module scope (`monitor_width, monitor_height = utils.take_screenshot()
5+
.size`), so `import openadapt_capture` crashed in any headless
6+
environment whose display reported a zero-size region, taking down
7+
`openadapt version`/`doctor` with it.
8+
9+
A subprocess import test is unreliable here (it only reproduces on a
10+
genuinely headless display), so this guards the invariant statically:
11+
no module-level call to a display/screenshot API in any package module.
12+
Importing a library should be cheap and side-effect free.
13+
"""
14+
15+
from __future__ import annotations
16+
17+
import ast
18+
from pathlib import Path
19+
20+
PACKAGE_ROOT = Path(__file__).resolve().parent.parent / "openadapt_capture"
21+
22+
# Calls that touch the screen/display and must not run at import time.
23+
FORBIDDEN_AT_MODULE_SCOPE = {"take_screenshot", "get_monitor_dims", "grab"}
24+
25+
26+
def _module_level_calls(tree: ast.Module):
    """Yield every ``ast.Call`` node that is evaluated while the module is imported.

    The walk follows Python's evaluation rules rather than visiting every
    descendant of each top-level statement:

    * Function and lambda *bodies* are skipped at any nesting depth (for
      example a fallback ``def`` inside ``try``/``except``), because a body
      only runs when the function is called.
    * Decorators and default argument values of a function ARE visited: they
      are evaluated when the ``def``/``lambda`` itself executes.
    * Class definitions are descended into: bases, keywords, decorators and
      class-body statements all execute at import time (methods inside the
      class are still skipped as function bodies).

    Args:
        tree: The parsed module.

    Yields:
        ast.Call: Import-time call nodes, in depth-first source order.
    """
    deferred_body_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)

    # Explicit stack (depth-first, pre-order); reversed so pop() returns
    # statements in source order.
    pending = list(reversed(tree.body))
    while pending:
        node = pending.pop()

        if isinstance(node, deferred_body_types):
            # Only the parts evaluated at definition time are followed.
            # Lambdas have no decorator_list, hence the getattr default.
            children = list(getattr(node, "decorator_list", []))
            children.extend(node.args.defaults)
            # kw_defaults holds None for keyword-only args without a default.
            children.extend(d for d in node.args.kw_defaults if d is not None)
        else:
            children = list(ast.iter_child_nodes(node))

        if isinstance(node, ast.Call):
            yield node

        pending.extend(reversed(children))
35+
36+
37+
def _called_name(call: ast.Call) -> str | None:
    """Return the simple name a call is made through, or ``None``.

    For a bare-name call such as ``grab()`` this is the identifier; for an
    attribute call such as ``utils.take_screenshot()`` it is the final
    attribute (``take_screenshot``). Any other callee expression (subscript,
    call result, ...) has no simple name and yields ``None``.
    """
    target = call.func
    # Map each supported callee node type to the field holding its name.
    for node_type, name_field in ((ast.Name, "id"), (ast.Attribute, "attr")):
        if isinstance(target, node_type):
            return getattr(target, name_field)
    return None
44+
45+
46+
def test_no_display_calls_at_module_scope():
    """Fail if any module under PACKAGE_ROOT calls a forbidden display API at import.

    Every ``*.py`` file below ``PACKAGE_ROOT`` is parsed (never imported), and
    each call reported by ``_module_level_calls`` whose simple name is listed
    in ``FORBIDDEN_AT_MODULE_SCOPE`` is recorded as ``path:lineno``.
    """
    offenders = []
    for source_file in PACKAGE_ROOT.rglob("*.py"):
        module_ast = ast.parse(
            source_file.read_text(encoding="utf-8"), filename=str(source_file)
        )
        # Pair each import-time call with the simple name it is invoked by.
        named_calls = ((_called_name(c), c) for c in _module_level_calls(module_ast))
        offenders.extend(
            f"{source_file}:{c.lineno}: module-level {name}()"
            for name, c in named_calls
            if name in FORBIDDEN_AT_MODULE_SCOPE
        )
    assert not offenders, (
        "Display/screenshot calls at import time break headless imports "
        "(CI, servers, containers). Move them inside the function that "
        "uses them:\n " + "\n ".join(offenders)
    )

0 commit comments

Comments
 (0)