Skip to content

Commit aa797dd

Browse files
abrichr and claude authored
feat: OCR text anchoring (Tier 1.5a) for grounding cascade (#259)
Add Phase 5 text anchoring on top of Phase 4 state narrowing: - grounding.py: run_ocr() with pytesseract (optional dep, graceful fallback), ground_by_text() with tiered scoring (exact/case-insensitive/ substring/fuzzy) and nearby-text proximity boost, plus helper functions _char_overlap_ratio, _bbox_center, _bbox_distance. - demo_executor.py: _try_text_anchoring() method inserted before VLM grounder calls for click/double_click actions. Returns action if best OCR candidate scores > 0.85, otherwise falls through to Tier 2. - tests/test_text_anchoring.py: 21 tests covering all scoring tiers, proximity boost, edge cases, and graceful pytesseract fallback. All tests use mocked OCR results (no pytesseract required). Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent a5ebabb commit aa797dd

3 files changed

Lines changed: 588 additions & 3 deletions

File tree

openadapt_evals/agents/demo_executor.py

Lines changed: 98 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,13 @@
2929

3030
from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation
3131
from openadapt_evals.demo_library import Demo, DemoStep
32-
from openadapt_evals.grounding import check_state_preconditions, verify_transition
32+
from openadapt_evals.grounding import (
33+
GroundingTarget,
34+
check_state_preconditions,
35+
ground_by_text,
36+
run_ocr,
37+
verify_transition,
38+
)
3339

3440
try:
3541
from openadapt_evals.integrations.weave_integration import weave_op
@@ -219,6 +225,79 @@ def run(
219225

220226
return score, screenshots
221227

228+
def _try_text_anchoring(
    self,
    screenshot: bytes,
    step: DemoStep,
) -> BenchmarkAction | None:
    """Try to resolve a click target via OCR text matching (Tier 1.5a).

    Builds a :class:`GroundingTarget` — either the one attached to the
    step, or a fresh one from the step's textual description — and asks
    :func:`ground_by_text` for candidates. A candidate scoring above
    ``0.85`` is converted into a click action; anything weaker returns
    ``None`` so the caller escalates to the VLM grounder (Tier 2).

    Args:
        screenshot: Current screenshot PNG bytes.
        step: The demo step being executed.

    Returns:
        A :class:`BenchmarkAction` when text anchoring is confident
        enough, otherwise ``None`` (fall through to Tier 2).
    """
    # Prefer an explicit GroundingTarget stored on the step; otherwise
    # derive one from the first non-empty description field.
    if isinstance(step.grounding_target, GroundingTarget):
        target = step.grounding_target
    else:
        description = (
            step.description
            or step.target_description
            or step.action_description
        )
        if not description:
            return None
        target = GroundingTarget(description=description)

    if not target.description:
        return None

    # Cheap OCR pass; empty results mean pytesseract is missing or the
    # screen has no readable text.
    ocr_results = run_ocr(screenshot)
    if not ocr_results:
        logger.debug("Tier 1.5a: no OCR results, falling through to VLM")
        return None

    candidates = ground_by_text(screenshot, target, ocr_results=ocr_results)
    if not candidates:
        logger.debug(
            "Tier 1.5a: no text matches for %r, falling through to VLM",
            target.description,
        )
        return None

    # Candidates arrive sorted best-first; only accept a high-confidence hit.
    best = candidates[0]
    if best.local_score <= 0.85:
        logger.debug(
            "Tier 1.5a: best score %.2f < 0.85 for %r, falling through",
            best.local_score,
            target.description,
        )
        return None

    logger.info(
        "Tier 1.5a (text anchor): %r matched %r at %s (score=%.2f)",
        target.description,
        best.matched_text,
        best.point,
        best.local_score,
    )
    return BenchmarkAction(
        type="click",
        x=best.point[0],
        y=best.point[1],
        raw_action={"tier": 1.5, "source": "ocr_text_anchor"},
    )
300+
222301
@weave_op
223302
def _execute_step(
224303
self,
@@ -228,6 +307,7 @@ def _execute_step(
228307
"""Produce an action for a demo step using tiered intelligence.
229308
230309
Tier 1: keyboard/type -> direct execution (no VLM).
310+
Tier 1.5a: click -> OCR text anchoring (cheap, no VLM).
231311
Tier 2: click -> grounder finds element by description.
232312
Tier 3: recovery -> planner reasons about unexpected state.
233313
"""
@@ -250,6 +330,12 @@ def _execute_step(
250330
return BenchmarkAction(type="type", text=text, raw_action={"tier": 1})
251331

252332
if step.action_type == "click":
333+
# Tier 1.5a: try OCR text anchoring first
334+
if obs.screenshot:
335+
text_action = self._try_text_anchoring(obs.screenshot, step)
336+
if text_action is not None:
337+
return text_action
338+
253339
# Tier 2: grounder finds element by description
254340
description = step.description or step.target_description
255341
if not description:
@@ -258,6 +344,17 @@ def _execute_step(
258344
return self._ground_click(obs, description)
259345

260346
if step.action_type == "double_click":
347+
# Tier 1.5a: try OCR text anchoring first
348+
if obs.screenshot:
349+
text_action = self._try_text_anchoring(obs.screenshot, step)
350+
if text_action is not None:
351+
return BenchmarkAction(
352+
type="double_click",
353+
x=text_action.x,
354+
y=text_action.y,
355+
raw_action=text_action.raw_action,
356+
)
357+
261358
description = step.description or step.target_description
262359
logger.info("Tier 2 (grounder): double-click %s", description)
263360
action = self._ground_click(obs, description)

openadapt_evals/grounding.py

Lines changed: 216 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1-
"""Grounding data model and state verification for the DemoExecutor cascade.
1+
"""Grounding data model, state verification, and text anchoring for the
2+
DemoExecutor cascade.
23
34
Defines GroundingTarget (stored per click step in demo) and
45
GroundingCandidate (produced by each tier during grounding).
56
6-
Also provides state-narrowing functions (Phase 4 of the cascade):
7+
Phase 4 — state-narrowing functions:
78
- ``check_state_preconditions``: verify the screen matches expectations
89
before grounding a click.
910
- ``verify_transition``: verify the expected state change occurred after
1011
clicking.
1112
13+
Phase 5 — OCR text anchoring (Tier 1.5a):
14+
- ``run_ocr``: extract text regions from a screenshot via pytesseract.
15+
- ``ground_by_text``: match a GroundingTarget against OCR text with
16+
tiered scoring (exact > case-insensitive > substring > fuzzy) and
17+
nearby-text proximity boosting.
18+
1219
See docs/design/grounding_cascade_design_v3.md for the full architecture.
1320
"""
1421

@@ -293,3 +300,210 @@ def verify_transition(
293300
)
294301

295302
return True, "transition verified"
303+
304+
305+
# ---------------------------------------------------------------------------
306+
# Phase 5: OCR text anchoring (Tier 1.5a)
307+
# ---------------------------------------------------------------------------
308+
309+
310+
def _char_overlap_ratio(a: str, b: str) -> float:
311+
"""Return the ratio of shared characters between *a* and *b*.
312+
313+
Uses character-level intersection (multiset) divided by the length of
314+
the longer string. This is *not* edit distance — it is deliberately
315+
cheap and order-insensitive.
316+
317+
Returns:
318+
A float in ``[0.0, 1.0]``.
319+
"""
320+
if not a or not b:
321+
return 0.0
322+
# Build character frequency maps
323+
from collections import Counter
324+
325+
ca = Counter(a.lower())
326+
cb = Counter(b.lower())
327+
overlap = sum((ca & cb).values())
328+
return overlap / max(len(a), len(b))
329+
330+
331+
def _bbox_center(bbox: list[int] | tuple[int, ...]) -> tuple[float, float]:
332+
"""Return the center ``(cx, cy)`` of an ``[x1, y1, x2, y2]`` bbox."""
333+
x1, y1, x2, y2 = bbox[:4]
334+
return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
335+
336+
337+
def _bbox_distance(
    a: list[int] | tuple[int, ...],
    b: list[int] | tuple[int, ...],
) -> float:
    """Euclidean distance between the centers of two bboxes."""
    import math

    cx_a, cy_a = _bbox_center(a)
    cx_b, cy_b = _bbox_center(b)
    dx, dy = cx_a - cx_b, cy_a - cy_b
    return math.sqrt(dx ** 2 + dy ** 2)
347+
348+
349+
def run_ocr(screenshot: bytes) -> list[dict]:
    """Extract text regions from a screenshot.

    Delegates to ``pytesseract`` when it is importable; otherwise returns
    an empty list so callers degrade gracefully (they must tolerate ``[]``).

    Args:
        screenshot: PNG image bytes.

    Returns:
        A list of dicts, each with ``"text"``, ``"bbox"``
        (``[x1, y1, x2, y2]``) and ``"confidence"`` (``0.0``–``1.0``).
    """
    try:
        import pytesseract  # type: ignore[import-untyped]
    except ImportError:
        logger.debug("pytesseract not installed — returning empty OCR results")
        return []

    try:
        import io

        from PIL import Image

        image = Image.open(io.BytesIO(screenshot))
        data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    except Exception as exc:
        logger.warning("OCR failed: %s", exc)
        return []

    regions: list[dict] = []
    for idx in range(len(data.get("text", []))):
        label = data["text"][idx].strip()
        if not label:
            continue
        confidence = float(data["conf"][idx])
        # Tesseract reports -1 confidence for non-text layout boxes; skip them.
        if confidence < 0:
            continue
        x1 = int(data["left"][idx])
        y1 = int(data["top"][idx])
        x2 = x1 + int(data["width"][idx])
        y2 = y1 + int(data["height"][idx])
        regions.append({
            "text": label,
            "bbox": [x1, y1, x2, y2],
            "confidence": confidence / 100.0,
        })
    return regions
397+
398+
399+
def ground_by_text(
    screenshot: bytes,
    target: GroundingTarget,
    ocr_results: list[dict] | None = None,
) -> list[GroundingCandidate]:
    """Match ``target.description`` against OCR text (Tier 1.5a).

    Cheaper than a VLM call, but only useful when the target element
    carries readable text.

    Scoring tiers, highest first:

    - **Exact match** (``0.95``): OCR text equals the description.
    - **Case-insensitive match** (``0.90``): equal after lowercasing.
    - **Substring match** (``0.70``): either string contains the other,
      case-insensitive.
    - **Fuzzy match** (``0.50``): character-overlap ratio above 80%.

    Candidates within 300px of any ``target.nearby_text`` hit get a
    single ``+0.05`` proximity boost, capped at ``1.0``.

    Args:
        screenshot: PNG bytes; OCR'd only when *ocr_results* is ``None``.
        target: :class:`GroundingTarget`; must have a ``description``.
        ocr_results: Optional pre-computed OCR output.

    Returns:
        At most 5 :class:`GroundingCandidate` objects, best score first;
        empty list when nothing matches.
    """
    if not target.description:
        return []

    if ocr_results is None:
        ocr_results = run_ocr(screenshot)
    if not ocr_results:
        return []

    query = target.description
    query_cf = query.lower()

    def _tier_score(text: str) -> float:
        # Tiered scoring; 0.0 signals "no match".
        lowered = text.lower()
        if text == query:
            return 0.95
        if lowered == query_cf:
            return 0.90
        if query_cf in lowered or lowered in query_cf:
            return 0.70
        if _char_overlap_ratio(query, text) > 0.80:
            return 0.50
        return 0.0

    matches: list[GroundingCandidate] = []
    for region in ocr_results:
        text = region.get("text", "")
        bbox = region.get("bbox")
        if not text or not bbox:
            continue
        score = _tier_score(text)
        if score == 0.0:
            continue
        cx, cy = _bbox_center(bbox)
        matches.append(
            GroundingCandidate(
                source="ocr",
                point=(int(cx), int(cy)),
                bbox=tuple(bbox[:4]),  # type: ignore[arg-type]
                local_score=score,
                matched_text=text,
                reasoning=f"OCR text match: {text!r} (score={score:.2f})",
            )
        )

    # Proximity boost: +0.05 (once per candidate) when near any nearby_text hit.
    if target.nearby_text and matches:
        anchor_boxes = [
            region["bbox"]
            for label in target.nearby_text
            for region in ocr_results
            if label.lower() in region.get("text", "").lower()
            and region.get("bbox")
        ]
        if anchor_boxes:
            proximity_threshold = 300.0  # pixels
            for cand in matches:
                if cand.bbox is None:
                    continue
                near_anchor = any(
                    _bbox_distance(list(cand.bbox), box) < proximity_threshold
                    for box in anchor_boxes
                )
                if near_anchor:
                    cand.local_score = min(1.0, cand.local_score + 0.05)
                    cand.reasoning = (
                        f"{cand.reasoning}, "
                        f"nearby boost (+0.05)"
                    )

    # Best first; stable sort preserves OCR order among equal scores.
    matches.sort(key=lambda c: c.local_score, reverse=True)
    return matches[:5]

0 commit comments

Comments
 (0)