Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion openadapt_evals/agents/demo_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@

from openadapt_evals.adapters.base import BenchmarkAction, BenchmarkObservation
from openadapt_evals.demo_library import Demo, DemoStep
from openadapt_evals.grounding import check_state_preconditions, verify_transition
from openadapt_evals.grounding import (
GroundingTarget,
check_state_preconditions,
ground_by_text,
run_ocr,
verify_transition,
)

try:
from openadapt_evals.integrations.weave_integration import weave_op
Expand Down Expand Up @@ -219,6 +225,79 @@ def run(

return score, screenshots

def _try_text_anchoring(
    self,
    screenshot: bytes,
    step: DemoStep,
    threshold: float = 0.85,
) -> BenchmarkAction | None:
    """Attempt to ground a click via OCR text anchoring (Tier 1.5a).

    Creates a :class:`GroundingTarget` from the step and runs OCR-based
    text matching. If the best candidate scores above *threshold*,
    returns a click action at those coordinates. Otherwise returns
    ``None`` so the caller falls through to the VLM grounder (Tier 2).

    Args:
        screenshot: Current screenshot PNG bytes.
        step: The demo step being executed.
        threshold: Minimum ``local_score`` needed to accept a text
            match. Defaults to ``0.85`` (the original hard-coded value,
            so existing callers see identical behavior).

    Returns:
        A :class:`BenchmarkAction` if text anchoring succeeds with high
        confidence, or ``None`` to fall through.
    """
    # Prefer an explicit GroundingTarget stored on the step.
    # isinstance() already rejects None, so no separate None check needed.
    if isinstance(step.grounding_target, GroundingTarget):
        target = step.grounding_target
    else:
        # Fall through the step's description fields in priority order.
        description = (
            step.description
            or step.target_description
            or step.action_description
        )
        if not description:
            return None
        target = GroundingTarget(description=description)

    # A stored target may still carry an empty description — nothing to match.
    if not target.description:
        return None

    # Run OCR and text grounding; both degrade gracefully to empty results.
    ocr_results = run_ocr(screenshot)
    if not ocr_results:
        logger.debug("Tier 1.5a: no OCR results, falling through to VLM")
        return None

    candidates = ground_by_text(screenshot, target, ocr_results=ocr_results)
    if not candidates:
        logger.debug(
            "Tier 1.5a: no text matches for %r, falling through to VLM",
            target.description,
        )
        return None

    # Candidates arrive sorted best-first; only the top one is considered.
    best = candidates[0]
    if best.local_score > threshold:
        logger.info(
            "Tier 1.5a (text anchor): %r matched %r at %s (score=%.2f)",
            target.description,
            best.matched_text,
            best.point,
            best.local_score,
        )
        return BenchmarkAction(
            type="click",
            x=best.point[0],
            y=best.point[1],
            raw_action={"tier": 1.5, "source": "ocr_text_anchor"},
        )

    logger.debug(
        "Tier 1.5a: best score %.2f < %.2f for %r, falling through",
        best.local_score,
        threshold,
        target.description,
    )
    return None

@weave_op
def _execute_step(
self,
Expand All @@ -228,6 +307,7 @@ def _execute_step(
"""Produce an action for a demo step using tiered intelligence.

Tier 1: keyboard/type -> direct execution (no VLM).
Tier 1.5a: click -> OCR text anchoring (cheap, no VLM).
Tier 2: click -> grounder finds element by description.
Tier 3: recovery -> planner reasons about unexpected state.
"""
Expand All @@ -250,6 +330,12 @@ def _execute_step(
return BenchmarkAction(type="type", text=text, raw_action={"tier": 1})

if step.action_type == "click":
# Tier 1.5a: try OCR text anchoring first
if obs.screenshot:
text_action = self._try_text_anchoring(obs.screenshot, step)
if text_action is not None:
return text_action

# Tier 2: grounder finds element by description
description = step.description or step.target_description
if not description:
Expand All @@ -258,6 +344,17 @@ def _execute_step(
return self._ground_click(obs, description)

if step.action_type == "double_click":
# Tier 1.5a: try OCR text anchoring first
if obs.screenshot:
text_action = self._try_text_anchoring(obs.screenshot, step)
if text_action is not None:
return BenchmarkAction(
type="double_click",
x=text_action.x,
y=text_action.y,
raw_action=text_action.raw_action,
)

description = step.description or step.target_description
logger.info("Tier 2 (grounder): double-click %s", description)
action = self._ground_click(obs, description)
Expand Down
218 changes: 216 additions & 2 deletions openadapt_evals/grounding.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
"""Grounding data model and state verification for the DemoExecutor cascade.
"""Grounding data model, state verification, and text anchoring for the
DemoExecutor cascade.

Defines GroundingTarget (stored per click step in demo) and
GroundingCandidate (produced by each tier during grounding).

Also provides state-narrowing functions (Phase 4 of the cascade):
Phase 4 — state-narrowing functions:
- ``check_state_preconditions``: verify the screen matches expectations
before grounding a click.
- ``verify_transition``: verify the expected state change occurred after
clicking.

Phase 5 — OCR text anchoring (Tier 1.5a):
- ``run_ocr``: extract text regions from a screenshot via pytesseract.
- ``ground_by_text``: match a GroundingTarget against OCR text with
tiered scoring (exact > case-insensitive > substring > fuzzy) and
nearby-text proximity boosting.

See docs/design/grounding_cascade_design_v3.md for the full architecture.
"""

Expand Down Expand Up @@ -293,3 +300,210 @@ def verify_transition(
)

return True, "transition verified"


# ---------------------------------------------------------------------------
# Phase 5: OCR text anchoring (Tier 1.5a)
# ---------------------------------------------------------------------------


def _char_overlap_ratio(a: str, b: str) -> float:
"""Return the ratio of shared characters between *a* and *b*.

Uses character-level intersection (multiset) divided by the length of
the longer string. This is *not* edit distance — it is deliberately
cheap and order-insensitive.

Returns:
A float in ``[0.0, 1.0]``.
"""
if not a or not b:
return 0.0
# Build character frequency maps
from collections import Counter

ca = Counter(a.lower())
cb = Counter(b.lower())
overlap = sum((ca & cb).values())
return overlap / max(len(a), len(b))


def _bbox_center(bbox: list[int] | tuple[int, ...]) -> tuple[float, float]:
"""Return the center ``(cx, cy)`` of an ``[x1, y1, x2, y2]`` bbox."""
x1, y1, x2, y2 = bbox[:4]
return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)


def _bbox_distance(
a: list[int] | tuple[int, ...],
b: list[int] | tuple[int, ...],
) -> float:
"""Euclidean distance between the centers of two bboxes."""
import math

ax, ay = _bbox_center(a)
bx, by = _bbox_center(b)
return math.sqrt((ax - bx) ** 2 + (ay - by) ** 2)


def run_ocr(screenshot: bytes) -> list[dict]:
    """Run OCR on a screenshot and return detected text regions.

    Uses ``pytesseract`` when available. If it is not installed, returns
    an empty list (graceful degradation — callers must handle ``[]``).

    Args:
        screenshot: PNG image bytes.

    Returns:
        List of dicts with keys ``"text"``, ``"bbox"`` (``[x1, y1, x2, y2]``),
        and ``"confidence"`` (``0.0``–``1.0``).
    """
    try:
        import pytesseract  # type: ignore[import-untyped]
    except ImportError:
        logger.debug("pytesseract not installed — returning empty OCR results")
        return []

    try:
        import io

        from PIL import Image

        image = Image.open(io.BytesIO(screenshot))
        data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    except Exception as exc:
        logger.warning("OCR failed: %s", exc)
        return []

    regions: list[dict] = []
    # Tesseract returns parallel lists keyed by column name; walk them by index.
    for i, raw_text in enumerate(data.get("text", [])):
        word = raw_text.strip()
        if not word:
            continue
        confidence = float(data["conf"][i])
        if confidence < 0:
            # Tesseract reports -1 confidence for non-text structural boxes.
            continue
        x1 = int(data["left"][i])
        y1 = int(data["top"][i])
        x2 = x1 + int(data["width"][i])
        y2 = y1 + int(data["height"][i])
        regions.append(
            {
                "text": word,
                "bbox": [x1, y1, x2, y2],
                "confidence": confidence / 100.0,
            }
        )
    return regions


def ground_by_text(
    screenshot: bytes,
    target: GroundingTarget,
    ocr_results: list[dict] | None = None,
) -> list[GroundingCandidate]:
    """Ground a target by matching its description against OCR text.

    This is **Tier 1.5a** in the grounding cascade — faster and cheaper
    than a VLM call, but only works when the target contains readable
    text.

    Scoring tiers (from highest to lowest):

    - **Exact match** (``0.95``): OCR text equals the target description.
    - **Case-insensitive match** (``0.90``): Matches after lowercasing.
    - **Substring match** (``0.70``): Target description is a substring of
      the OCR text (or vice-versa), case-insensitive.
    - **Fuzzy match** (``0.50``): Character-level overlap ratio > 80%.

    Candidates near ``target.nearby_text`` locations receive a ``+0.05``
    proximity boost (capped at ``1.0``).

    Args:
        screenshot: PNG image bytes (used for OCR if *ocr_results* not
            provided).
        target: :class:`GroundingTarget` with at least a ``description``.
        ocr_results: Pre-computed OCR results. When ``None``,
            :func:`run_ocr` is called on *screenshot*.

    Returns:
        Up to 5 :class:`GroundingCandidate` objects sorted by score
        (highest first). Empty list if no matches found.
    """
    if not target.description:
        return []

    if ocr_results is None:
        ocr_results = run_ocr(screenshot)
    if not ocr_results:
        return []

    query = target.description
    query_lower = query.lower()

    def _match_score(text: str) -> float:
        """Tiered score for one OCR string; 0.0 means no match."""
        lowered = text.lower()
        if text == query:
            return 0.95
        if lowered == query_lower:
            return 0.90
        if query_lower in lowered or lowered in query_lower:
            return 0.70
        if _char_overlap_ratio(query, text) > 0.80:
            return 0.50
        return 0.0

    candidates: list[GroundingCandidate] = []
    for region in ocr_results:
        text = region.get("text", "")
        bbox = region.get("bbox")
        if not text or not bbox:
            continue

        score = _match_score(text)
        if score == 0.0:
            continue

        cx, cy = _bbox_center(bbox)
        candidates.append(
            GroundingCandidate(
                source="ocr",
                point=(int(cx), int(cy)),
                bbox=tuple(bbox[:4]),  # type: ignore[arg-type]
                local_score=score,
                matched_text=text,
                reasoning=f"OCR text match: {text!r} (score={score:.2f})",
            )
        )

    # Proximity boost: +0.05 (once per candidate) when it sits near any
    # OCR region whose text contains one of target.nearby_text.
    if target.nearby_text and candidates:
        anchor_bboxes: list[list[int]] = [
            region["bbox"]
            for hint in target.nearby_text
            for region in ocr_results
            if hint.lower() in region.get("text", "").lower()
            and region.get("bbox")
        ]

        if anchor_bboxes:
            proximity_threshold = 300.0  # pixels
            for candidate in candidates:
                if candidate.bbox is None:
                    continue
                is_near = any(
                    _bbox_distance(list(candidate.bbox), anchor)
                    < proximity_threshold
                    for anchor in anchor_bboxes
                )
                if is_near:
                    candidate.local_score = min(1.0, candidate.local_score + 0.05)
                    candidate.reasoning = (
                        f"{candidate.reasoning}, nearby boost (+0.05)"
                    )

    # Stable sort by score (descending); keep the top 5.
    return sorted(candidates, key=lambda c: c.local_score, reverse=True)[:5]
Loading
Loading