diff --git a/openadapt_evals/agents/demo_executor.py b/openadapt_evals/agents/demo_executor.py index 1e92157..964cc4a 100644 --- a/openadapt_evals/agents/demo_executor.py +++ b/openadapt_evals/agents/demo_executor.py @@ -4,9 +4,9 @@ interpret them. The planner is only consulted as a recovery mechanism when the expected screen state doesn't match. -Tier 1 (deterministic): keyboard shortcuts, typing — execute directly. -Tier 2 (grounder-only): clicks — grounder finds element by description. -Tier 3 (planner recovery): unexpected state — planner reasons about +Tier 1 (deterministic): keyboard shortcuts, typing -- execute directly. +Tier 2 (grounder-only): clicks -- grounder finds element by description. +Tier 3 (planner recovery): unexpected state -- planner reasons about what to do when the demo doesn't match reality. Usage: @@ -29,6 +29,7 @@ from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation from openadapt_evals.demo_library import Demo, DemoStep +from openadapt_evals.grounding import check_state_preconditions, verify_transition try: from openadapt_evals.integrations.weave_integration import weave_op @@ -94,7 +95,7 @@ def run( screenshot_dir: Optional directory to save screenshots. Returns: - (score, screenshots) — score from evaluate_dense(). + (score, screenshots) -- score from evaluate_dense(). 
""" from openadapt_evals.adapters.rl_env import ResetConfig @@ -121,13 +122,32 @@ def run( for i, step in enumerate(demo.steps): logger.info( - "Demo step %d/%d: %s %s — %s", + "Demo step %d/%d: %s %s -- %s", i + 1, len(demo.steps), step.action_type, step.action_value or "", step.description, ) + # Phase 4: Pre-click state narrowing + if ( + step.action_type in ("click", "double_click") + and step.grounding_target is not None + and obs.screenshot + ): + ok, reason = check_state_preconditions( + obs.screenshot, + step.grounding_target, + ocr_fn=None, + ) + if not ok: + logger.warning( + "Step %d: state precondition failed: %s", + i + 1, reason, + ) + # Observational in Phase 4 -- proceed anyway. + # Blocking / state recovery deferred to later phase. + action = self._execute_step(step, obs) if action is None: logger.warning("Step %d: no action produced, skipping", i + 1) @@ -144,6 +164,23 @@ def run( step_result = self._dispatch_action(env, action) obs = step_result.observation + # Phase 4: Post-click transition verification + if ( + step.action_type in ("click", "double_click") + and step.grounding_target is not None + and obs.screenshot + ): + ok, reason = verify_transition( + obs.screenshot, + step.grounding_target, + ocr_fn=None, + ) + if not ok: + logger.warning( + "Step %d: transition verification failed: %s", + i + 1, reason, + ) + if obs.screenshot: screenshots.append(obs.screenshot) if screenshot_dir: @@ -190,9 +227,9 @@ def _execute_step( ) -> BenchmarkAction | None: """Produce an action for a demo step using tiered intelligence. - Tier 1: keyboard/type → direct execution (no VLM). - Tier 2: click → grounder finds element by description. - Tier 3: recovery → planner reasons about unexpected state. + Tier 1: keyboard/type -> direct execution (no VLM). + Tier 2: click -> grounder finds element by description. + Tier 3: recovery -> planner reasons about unexpected state. 
""" if step.action_type == "key": # Tier 1: deterministic keyboard action @@ -230,7 +267,7 @@ def _execute_step( ) return action - # Unknown action type — log and skip + # Unknown action type -- log and skip logger.warning("Unknown action type %r, skipping", step.action_type) return None @@ -306,7 +343,7 @@ def _ground_click_http( logger.info("HTTP grounder: %s", raw[:200]) - # Parse [x1,y1,x2,y2] bbox → center click + # Parse [x1,y1,x2,y2] bbox -> center click from openadapt_evals.agents.planner_grounder_agent import ( PlannerGrounderAgent, ) @@ -345,7 +382,7 @@ def _ground_click_vlm( if action.type == "done": logger.warning( - "Grounder could not find %r — returning click at center", + "Grounder could not find %r -- returning click at center", description, ) return BenchmarkAction(type="click", x=0.5, y=0.5) diff --git a/openadapt_evals/grounding.py b/openadapt_evals/grounding.py index 5e50cab..1d62422 100644 --- a/openadapt_evals/grounding.py +++ b/openadapt_evals/grounding.py @@ -1,15 +1,24 @@ -"""Grounding data model for the DemoExecutor cascade. +"""Grounding data model and state verification for the DemoExecutor cascade. Defines GroundingTarget (stored per click step in demo) and GroundingCandidate (produced by each tier during grounding). +Also provides state-narrowing functions (Phase 4 of the cascade): +- ``check_state_preconditions``: verify the screen matches expectations + before grounding a click. +- ``verify_transition``: verify the expected state change occurred after + clicking. + See docs/design/grounding_cascade_design_v3.md for the full architecture. 
# ---------------------------------------------------------------------------
# Phase 4: State narrowing -- pre-click and post-click verification
# ---------------------------------------------------------------------------


def _text_present(
    query: str,
    ocr_results: list[dict],
    case_sensitive: bool = False,
) -> bool:
    """Return whether *query* occurs as a substring of any OCR result text.

    Args:
        query: Text to search for.
        ocr_results: OCR items; each is a dict that normally carries a
            ``"text"`` key. A missing or ``None`` text is treated as an
            empty string, and any other non-string value is converted
            with ``str()``, so one malformed item cannot abort the check.
        case_sensitive: When ``False`` (the default) both sides are
            compared after ``str.casefold()``, which is the Unicode-aware
            form of caseless matching.

    Returns:
        ``True`` if *query* is a substring of at least one item's text.
        Plain substring semantics apply, so an empty *query* matches as
        soon as *ocr_results* contains any item.
    """
    needle = query if case_sensitive else query.casefold()
    for item in ocr_results:
        raw = item.get("text")
        text = "" if raw is None else str(raw)
        if not case_sensitive:
            text = text.casefold()
        if needle in text:
            return True
    return False


def _count_present(
    expected: list[str],
    ocr_results: list[dict],
) -> tuple[int, int]:
    """Count how many *expected* strings are on screen and how many are needed.

    The requirement is half of the expectations rounded down, but never
    fewer than one (so 3 expectations need 1 match, 4 need 2).

    Returns:
        ``(found, needed)``.
    """
    found = sum(1 for text in expected if _text_present(text, ocr_results))
    needed = max(1, len(expected) // 2)
    return found, needed


def check_state_preconditions(
    screenshot: bytes,
    target: GroundingTarget,
    ocr_fn: Callable[[bytes], list[dict]] | None = None,
) -> tuple[bool, str]:
    """Check if the current screen state matches the demo's expectations.

    This is the "state narrowing" step that runs *before* candidate
    generation. It is cheaper to detect "wrong screen" than to ground
    on it -- see the Phase 4 rationale in
    ``docs/design/grounding_cascade_design_v3.md``.

    Checks run in a fixed order and the first failure is reported:
    window title, then nearby text, then surrounding labels.

    Args:
        screenshot: Current screenshot PNG bytes; only ever handed to
            *ocr_fn*.
        target: :class:`GroundingTarget` whose ``window_title``,
            ``nearby_text`` and ``surrounding_labels`` are read.
        ocr_fn: Optional OCR function that accepts PNG bytes and returns
            ``list[dict]`` where each dict has at least a ``"text"`` key
            (and optionally ``"bbox"``). When *None*, precondition
            checks that require OCR are skipped gracefully. A ``None``
            return value is treated as "no text recognised".

    Returns:
        ``(preconditions_met, reason)`` -- ``True`` if safe to proceed
        with grounding, ``False`` with a human-readable reason string if
        state recovery is needed.
    """
    has_expectations = bool(
        target.window_title
        or target.nearby_text
        or target.surrounding_labels
    )

    # No text expectations on this target -- nothing to check.
    if not has_expectations:
        return True, "no text preconditions defined on target"

    # OCR unavailable -- skip gracefully (Phase 5 adds real OCR).
    if ocr_fn is None:
        return True, "no OCR available, skipping precondition check"

    # Guard against a backend that reports "nothing found" as None.
    ocr_results = ocr_fn(screenshot) or []

    # 1. Window title check
    if target.window_title and not _text_present(
        target.window_title, ocr_results
    ):
        return (
            False,
            f"window title mismatch: expected {target.window_title!r}",
        )

    # 2. Nearby text -- require half (rounded down, minimum one) present
    if target.nearby_text:
        found, needed = _count_present(target.nearby_text, ocr_results)
        if found < needed:
            return (
                False,
                f"nearby text mismatch: found {found}/{len(target.nearby_text)}"
                f" (need >= {needed})",
            )

    # 3. Surrounding labels -- same threshold rule as nearby text
    if target.surrounding_labels:
        found, needed = _count_present(target.surrounding_labels, ocr_results)
        if found < needed:
            return (
                False,
                f"surrounding labels mismatch: found "
                f"{found}/{len(target.surrounding_labels)}"
                f" (need >= {needed})",
            )

    return True, "preconditions met"


def verify_transition(
    screenshot_after: bytes,
    target: GroundingTarget,
    ocr_fn: Callable[[bytes], list[dict]] | None = None,
) -> tuple[bool, str]:
    """Verify that the click produced the expected state change.

    Uses structured transition expectations from :class:`GroundingTarget`:

    - ``disappearance_text``: text that should *no longer* be visible.
    - ``appearance_text``: text that should *now* be visible.
    - ``window_title_change``: expected new window title.
    - ``modal_toggled``: whether a modal appeared/disappeared (deferred
      until a modal-detection backend is available).

    Checks run in the order listed above and the first failure is
    reported. Every entry of ``disappearance_text`` must be absent and
    every entry of ``appearance_text`` must be present.

    Args:
        screenshot_after: Screenshot PNG bytes taken after the click;
            only ever handed to *ocr_fn*.
        target: :class:`GroundingTarget` with structured transition
            expectations.
        ocr_fn: Optional OCR function (same contract as
            :func:`check_state_preconditions`). When *None*, checks
            that require OCR are skipped gracefully. A ``None`` return
            value is treated as "no text recognised".

    Returns:
        ``(verified, reason)`` -- ``True`` if the transition looks
        correct, ``False`` with a human-readable reason if it looks
        wrong.
    """
    has_expectations = bool(
        target.disappearance_text
        or target.appearance_text
        or target.window_title_change is not None
        or target.modal_toggled is not None
    )

    # No structured transition expectations -- nothing to verify.
    if not has_expectations:
        return True, "no transition expectations defined on target"

    # OCR unavailable -- skip gracefully.
    if ocr_fn is None:
        return True, "no OCR available, skipping transition verification"

    # Guard against a backend that reports "nothing found" as None.
    ocr_results = ocr_fn(screenshot_after) or []

    # 1. Disappearance check -- text should have vanished.
    for text in target.disappearance_text or ():
        if _text_present(text, ocr_results):
            return (
                False,
                f"disappearance_text still present: {text!r}",
            )

    # 2. Appearance check -- text should now be visible.
    for text in target.appearance_text or ():
        if not _text_present(text, ocr_results):
            return (
                False,
                f"appearance_text not found: {text!r}",
            )

    # 3. Window title change
    if target.window_title_change is not None and not _text_present(
        target.window_title_change, ocr_results
    ):
        return (
            False,
            f"window title change not detected: "
            f"expected {target.window_title_change!r}",
        )

    # 4. Modal toggled -- deferred (requires modal detection backend).
    #    Log for observability but do not fail.
    if target.modal_toggled is not None:
        logger.debug(
            "modal_toggled=%s expectation set but no modal detection "
            "backend available -- skipping",
            target.modal_toggled,
        )

    return True, "transition verified"
"""Tests for grounding data model and Phase 4 state-narrowing functions.

Covers:
- check_state_preconditions (pre-click state verification)
- verify_transition (post-click transition verification)
- GroundingTarget round-trip serialization

No real OCR backend is involved: every test either passes ``ocr_fn=None``
or injects a fake built by ``_make_ocr_fn``.
"""

from __future__ import annotations

from openadapt_evals.grounding import (
    GroundingTarget,
    check_state_preconditions,
    verify_transition,
)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Just the 8-byte PNG signature. The bytes are never decoded here because
# the fake OCR functions below ignore their input.
DUMMY_SCREENSHOT = b"\x89PNG\r\n\x1a\n"  # minimal PNG header (content irrelevant)


def _make_ocr_fn(texts: list[str]):
    """Return an ocr_fn that always reports the given texts.

    Each text becomes one OCR item with a fixed placeholder ``bbox``; the
    screenshot argument is ignored.
    """

    def ocr_fn(_screenshot_bytes: bytes) -> list[dict]:
        return [{"text": t, "bbox": [0, 0, 100, 20]} for t in texts]

    return ocr_fn


# ---------------------------------------------------------------------------
# check_state_preconditions
# ---------------------------------------------------------------------------


class TestCheckStatePreconditions:
    """Tests for check_state_preconditions()."""

    def test_no_ocr_returns_true(self):
        """When no OCR function is provided, skip gracefully."""
        target = GroundingTarget(
            window_title="Notepad",
            nearby_text=["File", "Edit"],
        )
        ok, reason = check_state_preconditions(
            DUMMY_SCREENSHOT, target, ocr_fn=None
        )
        assert ok is True
        assert "no OCR available" in reason

    def test_no_expectations_returns_true(self):
        """When target has no text expectations, nothing to check."""
        target = GroundingTarget(description="some button")
        ok, reason = check_state_preconditions(
            DUMMY_SCREENSHOT, target, ocr_fn=_make_ocr_fn([])
        )
        assert ok is True
        assert "no text preconditions" in reason

    def test_with_ocr_window_title_match(self):
        """Window title found on screen => preconditions met."""
        target = GroundingTarget(window_title="Notepad")
        # Substring match: "Notepad" is contained in "Notepad - Untitled".
        ocr = _make_ocr_fn(["Notepad - Untitled", "File", "Edit"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True
        assert "preconditions met" in reason

    def test_with_ocr_window_title_mismatch(self):
        """Window title NOT found on screen => preconditions fail."""
        target = GroundingTarget(window_title="Notepad")
        ocr = _make_ocr_fn(["Chrome - Settings", "Privacy"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is False
        assert "window title mismatch" in reason
        assert "Notepad" in reason

    def test_with_ocr_nearby_text_match(self):
        """Enough nearby_text found => preconditions met."""
        target = GroundingTarget(nearby_text=["File", "Edit", "View", "Help"])
        ocr = _make_ocr_fn(["File", "Edit", "Format"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True  # 2/4 found, threshold is max(1, 4//2) = 2

    def test_with_ocr_nearby_text_mismatch(self):
        """Not enough nearby_text found => preconditions fail."""
        target = GroundingTarget(nearby_text=["File", "Edit", "View", "Help"])
        ocr = _make_ocr_fn(["Settings", "Privacy", "Security"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is False
        assert "nearby text mismatch" in reason

    def test_with_ocr_surrounding_labels_match(self):
        """Enough surrounding labels found."""
        target = GroundingTarget(
            surrounding_labels=["OK", "Cancel", "Apply"]
        )
        ocr = _make_ocr_fn(["OK", "Cancel", "Help"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True  # 2/3 found, threshold = max(1, 3//2) = 1

    def test_with_ocr_surrounding_labels_mismatch(self):
        """Not enough surrounding labels found."""
        target = GroundingTarget(
            surrounding_labels=["OK", "Cancel", "Apply"]
        )
        ocr = _make_ocr_fn(["Settings", "Privacy"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is False
        assert "surrounding labels mismatch" in reason

    def test_case_insensitive(self):
        """Text matching is case-insensitive by default."""
        target = GroundingTarget(window_title="Notepad")
        ocr = _make_ocr_fn(["NOTEPAD - Untitled"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True

    def test_combined_checks_all_pass(self):
        """Multiple checks all pass."""
        target = GroundingTarget(
            window_title="Notepad",
            nearby_text=["File", "Edit"],
            surrounding_labels=["Format"],
        )
        ocr = _make_ocr_fn(["Notepad", "File", "Edit", "Format", "Help"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True

    def test_combined_checks_title_fails(self):
        """Window title fails even if other checks would pass."""
        target = GroundingTarget(
            window_title="Notepad",
            nearby_text=["File", "Edit"],
        )
        # The title check runs first, so its failure is the one reported.
        ocr = _make_ocr_fn(["Chrome", "File", "Edit"])
        ok, reason = check_state_preconditions(DUMMY_SCREENSHOT, target, ocr)
        assert ok is False
        assert "window title mismatch" in reason


# ---------------------------------------------------------------------------
# verify_transition
# ---------------------------------------------------------------------------


class TestVerifyTransition:
    """Tests for verify_transition()."""

    def test_no_expectations_returns_true(self):
        """No structured transition expectations => pass."""
        target = GroundingTarget(description="some button")
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr_fn=None)
        assert ok is True
        assert "no transition expectations" in reason

    def test_no_ocr_returns_true(self):
        """OCR unavailable => skip gracefully."""
        target = GroundingTarget(appearance_text=["Success"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr_fn=None)
        assert ok is True
        assert "no OCR available" in reason

    def test_appearance_text_found(self):
        """Expected text appeared after click."""
        target = GroundingTarget(appearance_text=["Confirmation dialog"])
        ocr = _make_ocr_fn(["Confirmation dialog", "OK", "Cancel"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True
        assert "transition verified" in reason

    def test_appearance_text_missing(self):
        """Expected text did NOT appear."""
        target = GroundingTarget(appearance_text=["Confirmation dialog"])
        ocr = _make_ocr_fn(["Settings", "Privacy"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is False
        assert "appearance_text not found" in reason
        assert "Confirmation dialog" in reason

    def test_disappearance_text_gone(self):
        """Text that should vanish is indeed gone."""
        target = GroundingTarget(disappearance_text=["Loading..."])
        ocr = _make_ocr_fn(["Ready", "File", "Edit"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True
        assert "transition verified" in reason

    def test_disappearance_text_still_present(self):
        """Text that should vanish is still visible."""
        target = GroundingTarget(disappearance_text=["Loading..."])
        ocr = _make_ocr_fn(["Loading...", "Please wait"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is False
        assert "disappearance_text still present" in reason
        assert "Loading..." in reason

    def test_window_title_change_detected(self):
        """New window title detected after click."""
        target = GroundingTarget(window_title_change="Settings")
        ocr = _make_ocr_fn(["Settings - Chrome", "Privacy"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True

    def test_window_title_change_not_detected(self):
        """New window title NOT detected."""
        target = GroundingTarget(window_title_change="Settings")
        ocr = _make_ocr_fn(["Notepad - Untitled"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is False
        assert "window title change not detected" in reason

    def test_modal_toggled_skipped(self):
        """modal_toggled is set but no detection backend -- skips."""
        target = GroundingTarget(
            modal_toggled=True,
            appearance_text=["OK"],
        )
        ocr = _make_ocr_fn(["OK", "Cancel"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True  # modal check skipped, appearance_text passes

    def test_combined_appearance_and_disappearance(self):
        """Both appearance and disappearance expectations met."""
        target = GroundingTarget(
            appearance_text=["Saved"],
            disappearance_text=["Saving..."],
        )
        ocr = _make_ocr_fn(["Saved", "File", "Edit"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is True

    def test_combined_appearance_passes_disappearance_fails(self):
        """Appearance ok but disappearance still present."""
        target = GroundingTarget(
            appearance_text=["Saved"],
            disappearance_text=["Saving..."],
        )
        # Disappearance is checked before appearance, so it is the
        # failure that gets reported.
        ocr = _make_ocr_fn(["Saved", "Saving...", "File"])
        ok, reason = verify_transition(DUMMY_SCREENSHOT, target, ocr)
        assert ok is False
        assert "disappearance_text still present" in reason


# ---------------------------------------------------------------------------
# GroundingTarget round-trip serialization
# ---------------------------------------------------------------------------


class TestGroundingTargetRoundTrip:
    """Tests for to_dict / from_dict serialization."""

    def test_round_trip_preserves_all_fields(self):
        """to_dict() -> from_dict() preserves every non-default field."""
        original = GroundingTarget(
            description="Clear browsing data button",
            target_type="button",
            crop_path="crops/step_03.png",
            crop_bbox=(100, 200, 300, 250),
            click_offset=(50, 25),
            nearby_text=["Clear data", "Browsing history"],
            window_title="Chrome - Settings",
            surrounding_labels=["Cookies", "Cached images"],
            screenshot_before_path="screenshots/before_03.png",
            screenshot_after_path="screenshots/after_03.png",
            disappearance_text=["Clear browsing data"],
            appearance_text=["Your data has been cleared"],
            window_title_change="Chrome - New Tab",
            region_changed=(50, 100, 400, 500),
            modal_toggled=True,
            expected_change="Confirmation dialog appears",
        )

        d = original.to_dict()
        restored = GroundingTarget.from_dict(d)

        # Compared field by field so a failure names the field that was lost.
        assert restored.description == original.description
        assert restored.target_type == original.target_type
        assert restored.crop_path == original.crop_path
        assert restored.crop_bbox == original.crop_bbox
        assert restored.click_offset == original.click_offset
        assert restored.nearby_text == original.nearby_text
        assert restored.window_title == original.window_title
        assert restored.surrounding_labels == original.surrounding_labels
        assert restored.screenshot_before_path == original.screenshot_before_path
        assert restored.screenshot_after_path == original.screenshot_after_path
        assert restored.disappearance_text == original.disappearance_text
        assert restored.appearance_text == original.appearance_text
        assert restored.window_title_change == original.window_title_change
        assert restored.region_changed == original.region_changed
        assert restored.modal_toggled == original.modal_toggled
        assert restored.expected_change == original.expected_change

    def test_round_trip_tuple_from_list(self):
        """JSON serializes tuples as lists; from_dict converts back."""
        d = {
            "description": "test",
            "crop_bbox": [10, 20, 30, 40],
            "click_offset": [5, 5],
            "region_changed": [0, 0, 100, 100],
        }
        restored = GroundingTarget.from_dict(d)
        assert isinstance(restored.crop_bbox, tuple)
        assert isinstance(restored.click_offset, tuple)
        assert isinstance(restored.region_changed, tuple)
        assert restored.crop_bbox == (10, 20, 30, 40)

    def test_round_trip_defaults_omitted(self):
        """Default/empty fields are omitted from to_dict output."""
        target = GroundingTarget(description="test button")
        d = target.to_dict()
        assert "description" in d
        # Empty/default fields should not be present
        assert "nearby_text" not in d
        assert "crop_bbox" not in d
        assert "modal_toggled" not in d
        assert "window_title" not in d

    def test_round_trip_minimal(self):
        """A target with only defaults round-trips cleanly."""
        target = GroundingTarget()
        d = target.to_dict()
        assert d == {}  # everything is default
        restored = GroundingTarget.from_dict(d)
        assert restored.description == ""
        assert restored.nearby_text == []
        assert restored.crop_bbox is None