From b603a5c6d3412a3a0d08d9563637b92389f0ff50 Mon Sep 17 00:00:00 2001 From: Richard Abrich Date: Tue, 31 Mar 2026 18:44:03 -0400 Subject: [PATCH] feat: OCR text anchoring (Tier 1.5a) for grounding cascade Add Phase 5 text anchoring on top of Phase 4 state narrowing: - grounding.py: run_ocr() with pytesseract (optional dep, graceful fallback), ground_by_text() with tiered scoring (exact/case-insensitive/ substring/fuzzy) and nearby-text proximity boost, plus helper functions _char_overlap_ratio, _bbox_center, _bbox_distance. - demo_executor.py: _try_text_anchoring() method inserted before VLM grounder calls for click/double_click actions. Returns action if best OCR candidate scores > 0.85, otherwise falls through to Tier 2. - tests/test_text_anchoring.py: 21 tests covering all scoring tiers, proximity boost, edge cases, and graceful pytesseract fallback. All tests use mocked OCR results (no pytesseract required). Co-Authored-By: Claude Opus 4.6 (1M context) --- openadapt_evals/agents/demo_executor.py | 99 ++++++++- openadapt_evals/grounding.py | 218 ++++++++++++++++++- tests/test_text_anchoring.py | 274 ++++++++++++++++++++++++ 3 files changed, 588 insertions(+), 3 deletions(-) create mode 100644 tests/test_text_anchoring.py diff --git a/openadapt_evals/agents/demo_executor.py b/openadapt_evals/agents/demo_executor.py index 964cc4a..a39fb57 100644 --- a/openadapt_evals/agents/demo_executor.py +++ b/openadapt_evals/agents/demo_executor.py @@ -29,7 +29,13 @@ from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation from openadapt_evals.demo_library import Demo, DemoStep -from openadapt_evals.grounding import check_state_preconditions, verify_transition +from openadapt_evals.grounding import ( + GroundingTarget, + check_state_preconditions, + ground_by_text, + run_ocr, + verify_transition, +) try: from openadapt_evals.integrations.weave_integration import weave_op @@ -219,6 +225,79 @@ def run( return score, screenshots + def _try_text_anchoring( 
+ self, + screenshot: bytes, + step: DemoStep, + ) -> BenchmarkAction | None: + """Attempt to ground a click via OCR text anchoring (Tier 1.5a). + + Creates a :class:`GroundingTarget` from the step and runs OCR-based + text matching. If the best candidate scores above ``0.85``, returns + a click action at those coordinates. Otherwise returns ``None`` so + the caller falls through to the VLM grounder (Tier 2). + + Args: + screenshot: Current screenshot PNG bytes. + step: The demo step being executed. + + Returns: + A :class:`BenchmarkAction` if text anchoring succeeds with high + confidence, or ``None`` to fall through. + """ + # Build target from step's grounding_target or description + if step.grounding_target is not None and isinstance( + step.grounding_target, GroundingTarget + ): + target = step.grounding_target + else: + description = step.description or step.target_description + if not description: + description = step.action_description + if not description: + return None + target = GroundingTarget(description=description) + + if not target.description: + return None + + # Run OCR and text grounding + ocr_results = run_ocr(screenshot) + if not ocr_results: + logger.debug("Tier 1.5a: no OCR results, falling through to VLM") + return None + + candidates = ground_by_text(screenshot, target, ocr_results=ocr_results) + if not candidates: + logger.debug( + "Tier 1.5a: no text matches for %r, falling through to VLM", + target.description, + ) + return None + + best = candidates[0] + if best.local_score > 0.85: + logger.info( + "Tier 1.5a (text anchor): %r matched %r at %s (score=%.2f)", + target.description, + best.matched_text, + best.point, + best.local_score, + ) + return BenchmarkAction( + type="click", + x=best.point[0], + y=best.point[1], + raw_action={"tier": 1.5, "source": "ocr_text_anchor"}, + ) + + logger.debug( + "Tier 1.5a: best score %.2f < 0.85 for %r, falling through", + best.local_score, + target.description, + ) + return None + @weave_op def 
_execute_step( self, @@ -228,6 +307,7 @@ def _execute_step( """Produce an action for a demo step using tiered intelligence. Tier 1: keyboard/type -> direct execution (no VLM). + Tier 1.5a: click -> OCR text anchoring (cheap, no VLM). Tier 2: click -> grounder finds element by description. Tier 3: recovery -> planner reasons about unexpected state. """ @@ -250,6 +330,12 @@ def _execute_step( return BenchmarkAction(type="type", text=text, raw_action={"tier": 1}) if step.action_type == "click": + # Tier 1.5a: try OCR text anchoring first + if obs.screenshot: + text_action = self._try_text_anchoring(obs.screenshot, step) + if text_action is not None: + return text_action + # Tier 2: grounder finds element by description description = step.description or step.target_description if not description: @@ -258,6 +344,17 @@ def _execute_step( return self._ground_click(obs, description) if step.action_type == "double_click": + # Tier 1.5a: try OCR text anchoring first + if obs.screenshot: + text_action = self._try_text_anchoring(obs.screenshot, step) + if text_action is not None: + return BenchmarkAction( + type="double_click", + x=text_action.x, + y=text_action.y, + raw_action=text_action.raw_action, + ) + description = step.description or step.target_description logger.info("Tier 2 (grounder): double-click %s", description) action = self._ground_click(obs, description) diff --git a/openadapt_evals/grounding.py b/openadapt_evals/grounding.py index 1d62422..19c3ced 100644 --- a/openadapt_evals/grounding.py +++ b/openadapt_evals/grounding.py @@ -1,14 +1,21 @@ -"""Grounding data model and state verification for the DemoExecutor cascade. +"""Grounding data model, state verification, and text anchoring for the +DemoExecutor cascade. Defines GroundingTarget (stored per click step in demo) and GroundingCandidate (produced by each tier during grounding). 
-Also provides state-narrowing functions (Phase 4 of the cascade): +Phase 4 — state-narrowing functions: - ``check_state_preconditions``: verify the screen matches expectations before grounding a click. - ``verify_transition``: verify the expected state change occurred after clicking. +Phase 5 — OCR text anchoring (Tier 1.5a): +- ``run_ocr``: extract text regions from a screenshot via pytesseract. +- ``ground_by_text``: match a GroundingTarget against OCR text with + tiered scoring (exact > case-insensitive > substring > fuzzy) and + nearby-text proximity boosting. + See docs/design/grounding_cascade_design_v3.md for the full architecture. """ @@ -293,3 +300,210 @@ def verify_transition( ) return True, "transition verified" + + +# --------------------------------------------------------------------------- +# Phase 5: OCR text anchoring (Tier 1.5a) +# --------------------------------------------------------------------------- + + +def _char_overlap_ratio(a: str, b: str) -> float: + """Return the ratio of shared characters between *a* and *b*. + + Uses character-level intersection (multiset) divided by the length of + the longer string. This is *not* edit distance — it is deliberately + cheap and order-insensitive. + + Returns: + A float in ``[0.0, 1.0]``. 
+ """ + if not a or not b: + return 0.0 + # Build character frequency maps + from collections import Counter + + ca = Counter(a.lower()) + cb = Counter(b.lower()) + overlap = sum((ca & cb).values()) + return overlap / max(len(a), len(b)) + + +def _bbox_center(bbox: list[int] | tuple[int, ...]) -> tuple[float, float]: + """Return the center ``(cx, cy)`` of an ``[x1, y1, x2, y2]`` bbox.""" + x1, y1, x2, y2 = bbox[:4] + return ((x1 + x2) / 2.0, (y1 + y2) / 2.0) + + +def _bbox_distance( + a: list[int] | tuple[int, ...], + b: list[int] | tuple[int, ...], +) -> float: + """Euclidean distance between the centers of two bboxes.""" + import math + + ax, ay = _bbox_center(a) + bx, by = _bbox_center(b) + return math.sqrt((ax - bx) ** 2 + (ay - by) ** 2) + + +def run_ocr(screenshot: bytes) -> list[dict]: + """Run OCR on a screenshot and return detected text regions. + + Uses ``pytesseract`` when available. If it is not installed, returns + an empty list (graceful degradation — callers must handle ``[]``). + + Args: + screenshot: PNG image bytes. + + Returns: + List of dicts with keys ``"text"``, ``"bbox"`` (``[x1, y1, x2, y2]``), + and ``"confidence"`` (``0.0``–``1.0``). 
+ """ + try: + import pytesseract # type: ignore[import-untyped] + except ImportError: + logger.debug("pytesseract not installed — returning empty OCR results") + return [] + + try: + from PIL import Image + import io + + image = Image.open(io.BytesIO(screenshot)) + data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT) + except Exception as exc: + logger.warning("OCR failed: %s", exc) + return [] + + results: list[dict] = [] + n_boxes = len(data.get("text", [])) + for i in range(n_boxes): + text = data["text"][i].strip() + if not text: + continue + conf = float(data["conf"][i]) + if conf < 0: + continue + x = int(data["left"][i]) + y = int(data["top"][i]) + w = int(data["width"][i]) + h = int(data["height"][i]) + results.append({ + "text": text, + "bbox": [x, y, x + w, y + h], + "confidence": conf / 100.0, + }) + return results + + +def ground_by_text( + screenshot: bytes, + target: GroundingTarget, + ocr_results: list[dict] | None = None, +) -> list[GroundingCandidate]: + """Ground a target by matching its description against OCR text. + + This is **Tier 1.5a** in the grounding cascade — faster and cheaper + than a VLM call, but only works when the target contains readable + text. + + Scoring tiers (from highest to lowest): + + - **Exact match** (``0.95``): OCR text equals the target description. + - **Case-insensitive match** (``0.90``): Matches after lowercasing. + - **Substring match** (``0.70``): Target description is a substring of + the OCR text (or vice-versa), case-insensitive. + - **Fuzzy match** (``0.50``): Character-level overlap ratio > 80%. + + Candidates near ``target.nearby_text`` locations receive a ``+0.05`` + proximity boost (capped at ``1.0``). + + Args: + screenshot: PNG image bytes (used for OCR if *ocr_results* not + provided). + target: :class:`GroundingTarget` with at least a ``description``. + ocr_results: Pre-computed OCR results. When ``None``, + :func:`run_ocr` is called on *screenshot*. 
+ + Returns: + Up to 5 :class:`GroundingCandidate` objects sorted by score + (highest first). Empty list if no matches found. + """ + if not target.description: + return [] + + if ocr_results is None: + ocr_results = run_ocr(screenshot) + + if not ocr_results: + return [] + + query = target.description + query_lower = query.lower() + + candidates: list[GroundingCandidate] = [] + + for item in ocr_results: + text = item.get("text", "") + bbox = item.get("bbox") + if not text or not bbox: + continue + + text_lower = text.lower() + score = 0.0 + + # Tiered scoring + if text == query: + score = 0.95 + elif text_lower == query_lower: + score = 0.90 + elif query_lower in text_lower or text_lower in query_lower: + score = 0.70 + elif _char_overlap_ratio(query, text) > 0.80: + score = 0.50 + else: + continue # No match + + cx, cy = _bbox_center(bbox) + candidates.append( + GroundingCandidate( + source="ocr", + point=(int(cx), int(cy)), + bbox=tuple(bbox[:4]), # type: ignore[arg-type] + local_score=score, + matched_text=text, + reasoning=f"OCR text match: {text!r} (score={score:.2f})", + ) + ) + + # Proximity boost: +0.05 for candidates near nearby_text locations + if target.nearby_text and candidates: + # Find bboxes for nearby_text items + nearby_bboxes: list[list[int]] = [] + for nearby in target.nearby_text: + nearby_lower = nearby.lower() + for item in ocr_results: + item_text = item.get("text", "").lower() + if nearby_lower in item_text and item.get("bbox"): + nearby_bboxes.append(item["bbox"]) + + if nearby_bboxes: + proximity_threshold = 300.0 # pixels + for candidate in candidates: + if candidate.bbox is None: + continue + for nb_bbox in nearby_bboxes: + dist = _bbox_distance(list(candidate.bbox), nb_bbox) + if dist < proximity_threshold: + candidate.local_score = min( + 1.0, candidate.local_score + 0.05 + ) + candidate.reasoning = ( + f"{candidate.reasoning}, " + f"nearby boost (+0.05)" + ) + break # One boost per candidate + + # Sort by score (descending), return 
top 5 + candidates.sort(key=lambda c: c.local_score, reverse=True) + return candidates[:5] diff --git a/tests/test_text_anchoring.py b/tests/test_text_anchoring.py new file mode 100644 index 0000000..496f6bb --- /dev/null +++ b/tests/test_text_anchoring.py @@ -0,0 +1,274 @@ +"""Tests for Phase 5 OCR text anchoring (Tier 1.5a). + +Covers: +- run_ocr graceful fallback when pytesseract is not installed +- ground_by_text scoring tiers (exact, case-insensitive, substring, fuzzy) +- ground_by_text nearby-text proximity boost +- ground_by_text sorting and top-5 limit +- DemoExecutor._try_text_anchoring integration (with mocked OCR) +""" + +from __future__ import annotations + +import math +from unittest.mock import patch + +import pytest + +from openadapt_evals.grounding import ( + GroundingCandidate, + GroundingTarget, + _bbox_center, + _bbox_distance, + _char_overlap_ratio, + ground_by_text, + run_ocr, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +DUMMY_SCREENSHOT = b"\x89PNG\r\n\x1a\n" # minimal PNG header + + +def _make_ocr_results(entries: list[tuple[str, list[int]]]) -> list[dict]: + """Build mock OCR results from (text, bbox) pairs.""" + return [ + {"text": text, "bbox": bbox, "confidence": 0.95} + for text, bbox in entries + ] + + +# --------------------------------------------------------------------------- +# Helper function tests +# --------------------------------------------------------------------------- + + +class TestHelperFunctions: + """Tests for _char_overlap_ratio, _bbox_center, _bbox_distance.""" + + def test_char_overlap_ratio_identical(self): + assert _char_overlap_ratio("hello", "hello") == 1.0 + + def test_char_overlap_ratio_empty(self): + assert _char_overlap_ratio("", "hello") == 0.0 + assert _char_overlap_ratio("hello", "") == 0.0 + + def test_char_overlap_ratio_partial(self): + ratio = 
class TestRunOCR:
    """Behaviour of run_ocr() when the OCR backend is unavailable."""

    def test_run_ocr_no_pytesseract(self):
        """Missing pytesseract degrades gracefully to an empty list."""
        with patch.dict("sys.modules", {"pytesseract": None}):
            assert run_ocr(DUMMY_SCREENSHOT) == []


# ---------------------------------------------------------------------------
# ground_by_text
# ---------------------------------------------------------------------------


class TestGroundByText:
    """Scoring tiers, proximity boosting, and edge cases of ground_by_text()."""

    @staticmethod
    def _ground(description, ocr, **target_kwargs):
        """Build a target and run ground_by_text against mock OCR output."""
        target = GroundingTarget(description=description, **target_kwargs)
        return ground_by_text(DUMMY_SCREENSHOT, target, ocr_results=ocr)

    def test_ground_by_text_exact_match(self):
        """Exact text match scores 0.95."""
        hits = self._ground(
            "Save", _make_ocr_results([("Save", [100, 200, 150, 220])])
        )
        assert len(hits) == 1
        assert hits[0].local_score == 0.95
        assert hits[0].matched_text == "Save"
        assert hits[0].source == "ocr"

    def test_ground_by_text_case_insensitive(self):
        """Case-insensitive match scores 0.90."""
        hits = self._ground(
            "Save", _make_ocr_results([("save", [100, 200, 150, 220])])
        )
        assert len(hits) == 1
        assert hits[0].local_score == 0.90

    def test_ground_by_text_substring(self):
        """Substring match scores 0.70."""
        hits = self._ground(
            "Save", _make_ocr_results([("Save As...", [100, 200, 200, 220])])
        )
        assert len(hits) == 1
        assert hits[0].local_score == 0.70

    def test_ground_by_text_no_match(self):
        """Unrelated OCR text yields no candidates."""
        ocr = _make_ocr_results([
            ("File", [10, 10, 50, 30]),
            ("Edit", [60, 10, 100, 30]),
        ])
        assert self._ground("Export", ocr) == []

    def test_ground_by_text_sorted_by_score(self):
        """Candidates come back ordered best-first."""
        ocr = _make_ocr_results([
            ("save as", [200, 200, 300, 220]),  # substring match -> 0.70
            ("Save", [100, 200, 150, 220]),     # exact match -> 0.95
            ("SAVE", [300, 200, 350, 220]),     # case-insensitive -> 0.90
        ])
        hits = self._ground("Save", ocr)
        assert len(hits) == 3
        scores = [c.local_score for c in hits]
        assert scores == sorted(scores, reverse=True)
        # The exact match must rank first.
        assert hits[0].local_score == 0.95
        assert hits[0].matched_text == "Save"

    def test_ground_by_text_nearby_boost(self):
        """A candidate close to a nearby_text anchor gains +0.05 (capped)."""
        # "OK" button directly below a "Confirm" label.
        ocr = _make_ocr_results([
            ("OK", [100, 200, 140, 220]),
            ("Confirm", [90, 170, 180, 190]),
        ])
        hits = self._ground("OK", ocr, nearby_text=["Confirm"])
        assert len(hits) == 1
        # Exact match (0.95) + nearby boost (0.05) = 1.0 (capped).
        assert hits[0].local_score == 1.0
        assert "nearby boost" in hits[0].reasoning

    def test_ground_by_text_no_nearby_boost_when_far(self):
        """A distant nearby_text anchor does not boost the candidate."""
        ocr = _make_ocr_results([
            ("OK", [100, 200, 140, 220]),
            ("Confirm", [2000, 2000, 2100, 2020]),  # far beyond the threshold
        ])
        hits = self._ground("OK", ocr, nearby_text=["Confirm"])
        assert len(hits) == 1
        # Stays at the unboosted exact-match score.
        assert hits[0].local_score == 0.95

    def test_ground_by_text_with_mock_ocr(self):
        """With ocr_results=None, ground_by_text delegates to run_ocr."""
        mocked = _make_ocr_results([("Submit", [400, 500, 480, 520])])
        with patch(
            "openadapt_evals.grounding.run_ocr", return_value=mocked
        ):
            hits = ground_by_text(
                DUMMY_SCREENSHOT,
                GroundingTarget(description="Submit"),
                ocr_results=None,
            )
        assert len(hits) == 1
        assert hits[0].local_score == 0.95
        assert hits[0].matched_text == "Submit"
        cx, cy = _bbox_center([400, 500, 480, 520])
        assert hits[0].point == (int(cx), int(cy))

    def test_ground_by_text_empty_description(self):
        """An empty description short-circuits to no candidates."""
        ocr = _make_ocr_results([("Save", [100, 200, 150, 220])])
        assert self._ground("", ocr) == []

    def test_ground_by_text_empty_ocr(self):
        """Empty OCR results short-circuit to no candidates."""
        assert self._ground("Save", []) == []

    def test_ground_by_text_top_5_limit(self):
        """No more than five candidates are ever returned."""
        ocr = _make_ocr_results(
            [(f"Save {i}", [i * 50, 0, i * 50 + 40, 20]) for i in range(10)]
        )
        assert len(self._ground("Save", ocr)) <= 5

    def test_ground_by_text_point_is_bbox_center(self):
        """The candidate point is the matched bbox's center."""
        bbox = [100, 200, 300, 400]
        hits = self._ground("Click Me", _make_ocr_results([("Click Me", bbox)]))
        assert len(hits) == 1
        cx, cy = _bbox_center(bbox)
        assert hits[0].point == (int(cx), int(cy))

    def test_ground_by_text_fuzzy_match(self):
        """Fuzzy match (>80% character overlap) scores 0.50."""
        # "Settings" vs "Settingx": 7 shared chars over max length 8 gives
        # 0.875 — above the 0.80 fuzzy threshold, and neither string is a
        # substring of the other, so only the fuzzy tier fires.
        hits = self._ground(
            "Settings", _make_ocr_results([("Settingx", [100, 200, 200, 220])])
        )
        assert len(hits) == 1
        assert hits[0].local_score == 0.50