Skip to content

Commit e22b404

Browse files
abrichr and claude authored
feat: state narrowing and transition verification for grounding cascade (#257)
Phase 4 of the grounding cascade — detect "wrong screen" before grounding and verify state changes after clicking.

Added to grounding.py:
- check_state_preconditions(): verifies window title, nearby text, and surrounding labels match expectations before grounding a click. Skips gracefully when no OCR function is provided (real OCR arrives in Phase 5).
- verify_transition(): checks disappearance_text, appearance_text, and window_title_change against post-click screenshot via OCR. Modal detection deferred (logged, not enforced).
- _text_present(): case-insensitive substring matching helper.

Integrated into DemoExecutor.run():
- Pre-click: calls check_state_preconditions for click/double_click steps with a grounding_target. Observational only (warns, proceeds).
- Post-click: calls verify_transition after action dispatch. Observational only (warns, proceeds).

Tests (26 new):
- 11 tests for check_state_preconditions (no-OCR, no-expectations, window title match/mismatch, nearby text, surrounding labels, case insensitivity, combined checks)
- 11 tests for verify_transition (no-expectations, no-OCR, appearance/disappearance, window title change, modal skip, combined scenarios)
- 4 tests for GroundingTarget round-trip serialization

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 68e4b2a commit e22b404

File tree

3 files changed

+585
-13
lines changed

3 files changed

+585
-13
lines changed

openadapt_evals/agents/demo_executor.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
interpret them. The planner is only consulted as a recovery mechanism
55
when the expected screen state doesn't match.
66
7-
Tier 1 (deterministic): keyboard shortcuts, typing execute directly.
8-
Tier 2 (grounder-only): clicks grounder finds element by description.
9-
Tier 3 (planner recovery): unexpected state planner reasons about
7+
Tier 1 (deterministic): keyboard shortcuts, typing -- execute directly.
8+
Tier 2 (grounder-only): clicks -- grounder finds element by description.
9+
Tier 3 (planner recovery): unexpected state -- planner reasons about
1010
what to do when the demo doesn't match reality.
1111
1212
Usage:
@@ -29,6 +29,7 @@
2929

3030
from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation
3131
from openadapt_evals.demo_library import Demo, DemoStep
32+
from openadapt_evals.grounding import check_state_preconditions, verify_transition
3233

3334
try:
3435
from openadapt_evals.integrations.weave_integration import weave_op
@@ -94,7 +95,7 @@ def run(
9495
screenshot_dir: Optional directory to save screenshots.
9596
9697
Returns:
97-
(score, screenshots) score from evaluate_dense().
98+
(score, screenshots) -- score from evaluate_dense().
9899
"""
99100
from openadapt_evals.adapters.rl_env import ResetConfig
100101

@@ -121,13 +122,32 @@ def run(
121122

122123
for i, step in enumerate(demo.steps):
123124
logger.info(
124-
"Demo step %d/%d: %s %s %s",
125+
"Demo step %d/%d: %s %s -- %s",
125126
i + 1, len(demo.steps),
126127
step.action_type,
127128
step.action_value or "",
128129
step.description,
129130
)
130131

132+
# Phase 4: Pre-click state narrowing
133+
if (
134+
step.action_type in ("click", "double_click")
135+
and step.grounding_target is not None
136+
and obs.screenshot
137+
):
138+
ok, reason = check_state_preconditions(
139+
obs.screenshot,
140+
step.grounding_target,
141+
ocr_fn=None,
142+
)
143+
if not ok:
144+
logger.warning(
145+
"Step %d: state precondition failed: %s",
146+
i + 1, reason,
147+
)
148+
# Observational in Phase 4 -- proceed anyway.
149+
# Blocking / state recovery deferred to later phase.
150+
131151
action = self._execute_step(step, obs)
132152
if action is None:
133153
logger.warning("Step %d: no action produced, skipping", i + 1)
@@ -144,6 +164,23 @@ def run(
144164
step_result = self._dispatch_action(env, action)
145165
obs = step_result.observation
146166

167+
# Phase 4: Post-click transition verification
168+
if (
169+
step.action_type in ("click", "double_click")
170+
and step.grounding_target is not None
171+
and obs.screenshot
172+
):
173+
ok, reason = verify_transition(
174+
obs.screenshot,
175+
step.grounding_target,
176+
ocr_fn=None,
177+
)
178+
if not ok:
179+
logger.warning(
180+
"Step %d: transition verification failed: %s",
181+
i + 1, reason,
182+
)
183+
147184
if obs.screenshot:
148185
screenshots.append(obs.screenshot)
149186
if screenshot_dir:
@@ -190,9 +227,9 @@ def _execute_step(
190227
) -> BenchmarkAction | None:
191228
"""Produce an action for a demo step using tiered intelligence.
192229
193-
Tier 1: keyboard/type direct execution (no VLM).
194-
Tier 2: click grounder finds element by description.
195-
Tier 3: recovery planner reasons about unexpected state.
230+
Tier 1: keyboard/type -> direct execution (no VLM).
231+
Tier 2: click -> grounder finds element by description.
232+
Tier 3: recovery -> planner reasons about unexpected state.
196233
"""
197234
if step.action_type == "key":
198235
# Tier 1: deterministic keyboard action
@@ -230,7 +267,7 @@ def _execute_step(
230267
)
231268
return action
232269

233-
# Unknown action type log and skip
270+
# Unknown action type -- log and skip
234271
logger.warning("Unknown action type %r, skipping", step.action_type)
235272
return None
236273

@@ -306,7 +343,7 @@ def _ground_click_http(
306343

307344
logger.info("HTTP grounder: %s", raw[:200])
308345

309-
# Parse [x1,y1,x2,y2] bbox center click
346+
# Parse [x1,y1,x2,y2] bbox -> center click
310347
from openadapt_evals.agents.planner_grounder_agent import (
311348
PlannerGrounderAgent,
312349
)
@@ -345,7 +382,7 @@ def _ground_click_vlm(
345382

346383
if action.type == "done":
347384
logger.warning(
348-
"Grounder could not find %r returning click at center",
385+
"Grounder could not find %r -- returning click at center",
349386
description,
350387
)
351388
return BenchmarkAction(type="click", x=0.5, y=0.5)

openadapt_evals/grounding.py

Lines changed: 208 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,24 @@
1-
"""Grounding data model for the DemoExecutor cascade.
1+
"""Grounding data model and state verification for the DemoExecutor cascade.
22
33
Defines GroundingTarget (stored per click step in demo) and
44
GroundingCandidate (produced by each tier during grounding).
55
6+
Also provides state-narrowing functions (Phase 4 of the cascade):
7+
- ``check_state_preconditions``: verify the screen matches expectations
8+
before grounding a click.
9+
- ``verify_transition``: verify the expected state change occurred after
10+
clicking.
11+
612
See docs/design/grounding_cascade_design_v3.md for the full architecture.
713
"""
814

915
from __future__ import annotations
1016

17+
import logging
1118
from dataclasses import dataclass, field
12-
from typing import Any
19+
from typing import Any, Callable
20+
21+
logger = logging.getLogger(__name__)
1322

1423

1524
@dataclass
@@ -87,3 +96,200 @@ class GroundingCandidate:
8796
spatial_score: float | None = None # consistency with demo position
8897
visual_verify_score: float | None = None # crop resemblance to target
8998
accepted: bool = False
99+
100+
101+
# ---------------------------------------------------------------------------
102+
# Phase 4: State narrowing -- pre-click and post-click verification
103+
# ---------------------------------------------------------------------------
104+
105+
106+
def _text_present(
    query: str,
    ocr_results: list[dict],
    case_sensitive: bool = False,
) -> bool:
    """Check whether *query* appears in any OCR result text.

    Matching is plain substring containment, evaluated against one OCR
    result at a time; text from different results is never joined, so a
    query that spans two results does not match.

    Args:
        query: Text to search for.
        ocr_results: List of dicts with at least a ``"text"`` key. A
            missing key is treated as an empty string. Entries whose
            ``"text"`` value is not a string (e.g. ``None``) are ignored
            rather than raising.
        case_sensitive: Whether the comparison is case-sensitive. When
            ``False`` (the default) both sides are lowercased first.

    Returns:
        ``True`` if *query* is a substring of any OCR result text.
    """
    needle = query if case_sensitive else query.lower()
    for item in ocr_results:
        text = item.get("text", "")
        # An OCR entry may carry a non-string payload; skipping it keeps a
        # single bad entry from aborting the whole check.
        if not isinstance(text, str):
            continue
        haystack = text if case_sensitive else text.lower()
        if needle in haystack:
            return True
    return False
130+
131+
132+
def _half_present_failure(
    kind: str,
    expected: list[str],
    ocr_results: list[dict],
) -> str | None:
    """Return a mismatch reason if too few *expected* strings are on screen.

    The pass mark is ``max(1, len(expected) // 2)``: half of the expected
    strings rounded *down*, but never fewer than one (so 1 of 3 passes).

    Returns:
        A human-readable reason when the pass mark is missed, else ``None``.
    """
    found = sum(1 for text in expected if _text_present(text, ocr_results))
    threshold = max(1, len(expected) // 2)
    if found < threshold:
        return (
            f"{kind} mismatch: found {found}/{len(expected)}"
            f" (need >= {threshold})"
        )
    return None


def check_state_preconditions(
    screenshot: bytes,
    target: GroundingTarget,
    ocr_fn: Callable[[bytes], list[dict]] | None = None,
) -> tuple[bool, str]:
    """Check if the current screen state matches the demo's expectations.

    This is the "state narrowing" step that runs *before* candidate
    generation. It is cheaper to detect "wrong screen" than to ground
    on it -- see the Phase 4 rationale in
    ``docs/design/grounding_cascade_design_v3.md``.

    Checks run in order and stop at the first failure:

    1. ``target.window_title`` must appear in the OCR text.
    2. ``target.nearby_text``: at least ``max(1, n // 2)`` of the ``n``
       entries must appear.
    3. ``target.surrounding_labels``: same rule as ``nearby_text``.

    Args:
        screenshot: Current screenshot PNG bytes.
        target: :class:`GroundingTarget` with ``window_title``,
            ``nearby_text``, ``surrounding_labels``, etc.
        ocr_fn: Optional OCR function that accepts PNG bytes and returns
            ``list[dict]`` where each dict has at least a ``"text"`` key
            (and optionally ``"bbox"``). When *None*, precondition
            checks that require OCR are skipped gracefully. A falsy
            return value (``None`` or an empty list) is treated as "no
            text detected".

    Returns:
        ``(preconditions_met, reason)`` -- ``True`` if safe to proceed
        with grounding, ``False`` with a human-readable reason string if
        state recovery is needed.
    """
    # No text expectations on this target -- nothing to check.
    if not (
        target.window_title
        or target.nearby_text
        or target.surrounding_labels
    ):
        return True, "no text preconditions defined on target"

    # OCR unavailable -- skip gracefully (Phase 5 adds real OCR).
    if ocr_fn is None:
        return True, "no OCR available, skipping precondition check"

    # Normalise a None result so the checks below report a mismatch
    # instead of raising while iterating.
    ocr_results = ocr_fn(screenshot) or []

    # 1. Window title check
    if target.window_title and not _text_present(
        target.window_title, ocr_results
    ):
        return (
            False,
            f"window title mismatch: expected {target.window_title!r}",
        )

    # 2 + 3. Nearby text, then surrounding labels -- both use the same
    # "half rounded down, minimum one" pass mark.
    for kind, expected in (
        ("nearby text", target.nearby_text),
        ("surrounding labels", target.surrounding_labels),
    ):
        if not expected:
            continue
        reason = _half_present_failure(kind, expected, ocr_results)
        if reason is not None:
            return False, reason

    return True, "preconditions met"
212+
213+
214+
def verify_transition(
    screenshot_after: bytes,
    target: GroundingTarget,
    ocr_fn: Callable[[bytes], list[dict]] | None = None,
) -> tuple[bool, str]:
    """Verify that the click produced the expected state change.

    Uses structured transition expectations from :class:`GroundingTarget`,
    checked in this order (first failure wins):

    - ``disappearance_text``: text that should *no longer* be visible.
    - ``appearance_text``: text that should *now* be visible.
    - ``window_title_change``: expected new window title.
    - ``modal_toggled``: whether a modal appeared/disappeared (deferred
      until a modal-detection backend is available; only logged at debug
      level, never a failure).

    Args:
        screenshot_after: Screenshot PNG bytes taken after the click.
        target: :class:`GroundingTarget` with structured transition
            expectations.
        ocr_fn: Optional OCR function (same contract as
            :func:`check_state_preconditions`). When *None*, checks
            that require OCR are skipped gracefully. A falsy return
            value (``None`` or an empty list) is treated as "no text
            detected".

    Returns:
        ``(verified, reason)`` -- ``True`` if the transition looks
        correct, ``False`` with a human-readable reason if it looks
        wrong.
    """
    # No structured transition expectations -- nothing to verify.
    if not (
        target.disappearance_text
        or target.appearance_text
        or target.window_title_change is not None
        or target.modal_toggled is not None
    ):
        return True, "no transition expectations defined on target"

    # OCR unavailable -- skip gracefully.
    if ocr_fn is None:
        return True, "no OCR available, skipping transition verification"

    # Normalise a None result so the checks below report a mismatch
    # instead of raising while iterating.
    ocr_results = ocr_fn(screenshot_after) or []

    # 1. Disappearance check -- text should have vanished.
    for text in target.disappearance_text or ():
        if _text_present(text, ocr_results):
            return (
                False,
                f"disappearance_text still present: {text!r}",
            )

    # 2. Appearance check -- text should now be visible.
    for text in target.appearance_text or ():
        if not _text_present(text, ocr_results):
            return (
                False,
                f"appearance_text not found: {text!r}",
            )

    # 3. Window title change
    if target.window_title_change is not None and not _text_present(
        target.window_title_change, ocr_results
    ):
        return (
            False,
            f"window title change not detected: "
            f"expected {target.window_title_change!r}",
        )

    # 4. Modal toggled -- deferred (requires modal detection backend).
    #    Log for observability but do not fail.
    if target.modal_toggled is not None:
        logger.debug(
            "modal_toggled=%s expectation set but no modal detection "
            "backend available -- skipping",
            target.modal_toggled,
        )

    return True, "transition verified"

0 commit comments

Comments
 (0)