Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 48 additions & 11 deletions openadapt_evals/agents/demo_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
interpret them. The planner is only consulted as a recovery mechanism
when the expected screen state doesn't match.

Tier 1 (deterministic): keyboard shortcuts, typing execute directly.
Tier 2 (grounder-only): clicks grounder finds element by description.
Tier 3 (planner recovery): unexpected state planner reasons about
Tier 1 (deterministic): keyboard shortcuts, typing -- execute directly.
Tier 2 (grounder-only): clicks -- grounder finds element by description.
Tier 3 (planner recovery): unexpected state -- planner reasons about
what to do when the demo doesn't match reality.

Usage:
Expand All @@ -29,6 +29,7 @@

from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation
from openadapt_evals.demo_library import Demo, DemoStep
from openadapt_evals.grounding import check_state_preconditions, verify_transition

try:
from openadapt_evals.integrations.weave_integration import weave_op
Expand Down Expand Up @@ -94,7 +95,7 @@ def run(
screenshot_dir: Optional directory to save screenshots.

Returns:
(score, screenshots) score from evaluate_dense().
(score, screenshots) -- score from evaluate_dense().
"""
from openadapt_evals.adapters.rl_env import ResetConfig

Expand All @@ -121,13 +122,32 @@ def run(

for i, step in enumerate(demo.steps):
logger.info(
"Demo step %d/%d: %s %s %s",
"Demo step %d/%d: %s %s -- %s",
i + 1, len(demo.steps),
step.action_type,
step.action_value or "",
step.description,
)

# Phase 4: Pre-click state narrowing
if (
step.action_type in ("click", "double_click")
and step.grounding_target is not None
and obs.screenshot
):
ok, reason = check_state_preconditions(
obs.screenshot,
step.grounding_target,
ocr_fn=None,
)
if not ok:
logger.warning(
"Step %d: state precondition failed: %s",
i + 1, reason,
)
# Observational in Phase 4 -- proceed anyway.
# Blocking / state recovery deferred to later phase.

action = self._execute_step(step, obs)
if action is None:
logger.warning("Step %d: no action produced, skipping", i + 1)
Expand All @@ -144,6 +164,23 @@ def run(
step_result = self._dispatch_action(env, action)
obs = step_result.observation

# Phase 4: Post-click transition verification
if (
step.action_type in ("click", "double_click")
and step.grounding_target is not None
and obs.screenshot
):
ok, reason = verify_transition(
obs.screenshot,
step.grounding_target,
ocr_fn=None,
)
if not ok:
logger.warning(
"Step %d: transition verification failed: %s",
i + 1, reason,
)

if obs.screenshot:
screenshots.append(obs.screenshot)
if screenshot_dir:
Expand Down Expand Up @@ -190,9 +227,9 @@ def _execute_step(
) -> BenchmarkAction | None:
"""Produce an action for a demo step using tiered intelligence.

Tier 1: keyboard/type direct execution (no VLM).
Tier 2: click grounder finds element by description.
Tier 3: recovery planner reasons about unexpected state.
Tier 1: keyboard/type -> direct execution (no VLM).
Tier 2: click -> grounder finds element by description.
Tier 3: recovery -> planner reasons about unexpected state.
"""
if step.action_type == "key":
# Tier 1: deterministic keyboard action
Expand Down Expand Up @@ -230,7 +267,7 @@ def _execute_step(
)
return action

# Unknown action type log and skip
# Unknown action type -- log and skip
logger.warning("Unknown action type %r, skipping", step.action_type)
return None

Expand Down Expand Up @@ -306,7 +343,7 @@ def _ground_click_http(

logger.info("HTTP grounder: %s", raw[:200])

# Parse [x1,y1,x2,y2] bbox center click
# Parse [x1,y1,x2,y2] bbox -> center click
from openadapt_evals.agents.planner_grounder_agent import (
PlannerGrounderAgent,
)
Expand Down Expand Up @@ -345,7 +382,7 @@ def _ground_click_vlm(

if action.type == "done":
logger.warning(
"Grounder could not find %r returning click at center",
"Grounder could not find %r -- returning click at center",
description,
)
return BenchmarkAction(type="click", x=0.5, y=0.5)
Expand Down
210 changes: 208 additions & 2 deletions openadapt_evals/grounding.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,24 @@
"""Grounding data model for the DemoExecutor cascade.
"""Grounding data model and state verification for the DemoExecutor cascade.

Defines GroundingTarget (stored per click step in demo) and
GroundingCandidate (produced by each tier during grounding).

Also provides state-narrowing functions (Phase 4 of the cascade):
- ``check_state_preconditions``: verify the screen matches expectations
before grounding a click.
- ``verify_transition``: verify the expected state change occurred after
clicking.

See docs/design/grounding_cascade_design_v3.md for the full architecture.
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any
from typing import Any, Callable

logger = logging.getLogger(__name__)


@dataclass
Expand Down Expand Up @@ -87,3 +96,200 @@ class GroundingCandidate:
spatial_score: float | None = None # consistency with demo position
visual_verify_score: float | None = None # crop resemblance to target
accepted: bool = False


# ---------------------------------------------------------------------------
# Phase 4: State narrowing -- pre-click and post-click verification
# ---------------------------------------------------------------------------


def _text_present(
query: str,
ocr_results: list[dict],
case_sensitive: bool = False,
) -> bool:
"""Check whether *query* appears in any OCR result text.

Args:
query: Text to search for.
ocr_results: List of dicts with at least a ``"text"`` key.
case_sensitive: Whether the comparison is case-sensitive.

Returns:
``True`` if *query* is a substring of any OCR result text.
"""
if not case_sensitive:
query = query.lower()
for item in ocr_results:
text = item.get("text", "")
if not case_sensitive:
text = text.lower()
if query in text:
return True
return False


def check_state_preconditions(
screenshot: bytes,
target: GroundingTarget,
ocr_fn: Callable[[bytes], list[dict]] | None = None,
) -> tuple[bool, str]:
"""Check if the current screen state matches the demo's expectations.

This is the "state narrowing" step that runs *before* candidate
generation. It is cheaper to detect "wrong screen" than to ground
on it -- see the Phase 4 rationale in
``docs/design/grounding_cascade_design_v3.md``.

Args:
screenshot: Current screenshot PNG bytes.
target: :class:`GroundingTarget` with ``window_title``,
``nearby_text``, ``surrounding_labels``, etc.
ocr_fn: Optional OCR function that accepts PNG bytes and returns
``list[dict]`` where each dict has at least a ``"text"`` key
(and optionally ``"bbox"``). When *None*, precondition
checks that require OCR are skipped gracefully.

Returns:
``(preconditions_met, reason)`` -- ``True`` if safe to proceed
with grounding, ``False`` with a human-readable reason string if
state recovery is needed.
"""
has_expectations = bool(
target.window_title
or target.nearby_text
or target.surrounding_labels
)

# No text expectations on this target -- nothing to check.
if not has_expectations:
return True, "no text preconditions defined on target"

# OCR unavailable -- skip gracefully (Phase 5 adds real OCR).
if ocr_fn is None:
return True, "no OCR available, skipping precondition check"

ocr_results = ocr_fn(screenshot)

# 1. Window title check
if target.window_title:
if not _text_present(target.window_title, ocr_results):
return (
False,
f"window title mismatch: expected {target.window_title!r}",
)

# 2. Nearby text -- require at least half to be present
if target.nearby_text:
found = sum(
1 for t in target.nearby_text if _text_present(t, ocr_results)
)
threshold = max(1, len(target.nearby_text) // 2)
if found < threshold:
return (
False,
f"nearby text mismatch: found {found}/{len(target.nearby_text)}"
f" (need >= {threshold})",
)

# 3. Surrounding labels -- require at least half to be present
if target.surrounding_labels:
found = sum(
1
for t in target.surrounding_labels
if _text_present(t, ocr_results)
)
threshold = max(1, len(target.surrounding_labels) // 2)
if found < threshold:
return (
False,
f"surrounding labels mismatch: found "
f"{found}/{len(target.surrounding_labels)}"
f" (need >= {threshold})",
)

return True, "preconditions met"


def verify_transition(
screenshot_after: bytes,
target: GroundingTarget,
ocr_fn: Callable[[bytes], list[dict]] | None = None,
) -> tuple[bool, str]:
"""Verify that the click produced the expected state change.

Uses structured transition expectations from :class:`GroundingTarget`:

- ``disappearance_text``: text that should *no longer* be visible.
- ``appearance_text``: text that should *now* be visible.
- ``window_title_change``: expected new window title.
- ``modal_toggled``: whether a modal appeared/disappeared (deferred
until a modal-detection backend is available).

Args:
screenshot_after: Screenshot PNG bytes taken after the click.
target: :class:`GroundingTarget` with structured transition
expectations.
ocr_fn: Optional OCR function (same contract as
:func:`check_state_preconditions`). When *None*, checks
that require OCR are skipped gracefully.

Returns:
``(verified, reason)`` -- ``True`` if the transition looks
correct, ``False`` with a human-readable reason if it looks
wrong.
"""
has_expectations = bool(
target.disappearance_text
or target.appearance_text
or target.window_title_change is not None
or target.modal_toggled is not None
)

# No structured transition expectations -- nothing to verify.
if not has_expectations:
return True, "no transition expectations defined on target"

# OCR unavailable -- skip gracefully.
if ocr_fn is None:
return True, "no OCR available, skipping transition verification"

ocr_results = ocr_fn(screenshot_after)

# 1. Disappearance check -- text should have vanished.
if target.disappearance_text:
for text in target.disappearance_text:
if _text_present(text, ocr_results):
return (
False,
f"disappearance_text still present: {text!r}",
)

# 2. Appearance check -- text should now be visible.
if target.appearance_text:
for text in target.appearance_text:
if not _text_present(text, ocr_results):
return (
False,
f"appearance_text not found: {text!r}",
)

# 3. Window title change
if target.window_title_change is not None:
if not _text_present(target.window_title_change, ocr_results):
return (
False,
f"window title change not detected: "
f"expected {target.window_title_change!r}",
)

# 4. Modal toggled -- deferred (requires modal detection backend).
# Log for observability but do not fail.
if target.modal_toggled is not None:
logger.debug(
"modal_toggled=%s expectation set but no modal detection "
"backend available -- skipping",
target.modal_toggled,
)

return True, "transition verified"
Loading
Loading