|
1 | | -"""Grounding data model and state verification for the DemoExecutor cascade. |
| 1 | +"""Grounding data model, state verification, and text anchoring for the |
| 2 | +DemoExecutor cascade. |
2 | 3 |
|
3 | 4 | Defines GroundingTarget (stored per click step in demo) and |
4 | 5 | GroundingCandidate (produced by each tier during grounding). |
5 | 6 |
|
6 | | -Also provides state-narrowing functions (Phase 4 of the cascade): |
| 7 | +Phase 4 — state-narrowing functions: |
7 | 8 | - ``check_state_preconditions``: verify the screen matches expectations |
8 | 9 | before grounding a click. |
9 | 10 | - ``verify_transition``: verify the expected state change occurred after |
10 | 11 | clicking. |
11 | 12 |
|
| 13 | +Phase 5 — OCR text anchoring (Tier 1.5a): |
| 14 | +- ``run_ocr``: extract text regions from a screenshot via pytesseract. |
| 15 | +- ``ground_by_text``: match a GroundingTarget against OCR text with |
| 16 | + tiered scoring (exact > case-insensitive > substring > fuzzy) and |
| 17 | + nearby-text proximity boosting. |
| 18 | +
|
12 | 19 | See docs/design/grounding_cascade_design_v3.md for the full architecture. |
13 | 20 | """ |
14 | 21 |
|
@@ -293,3 +300,210 @@ def verify_transition( |
293 | 300 | ) |
294 | 301 |
|
295 | 302 | return True, "transition verified" |
| 303 | + |
| 304 | + |
| 305 | +# --------------------------------------------------------------------------- |
| 306 | +# Phase 5: OCR text anchoring (Tier 1.5a) |
| 307 | +# --------------------------------------------------------------------------- |
| 308 | + |
| 309 | + |
| 310 | +def _char_overlap_ratio(a: str, b: str) -> float: |
| 311 | + """Return the ratio of shared characters between *a* and *b*. |
| 312 | +
|
| 313 | + Uses character-level intersection (multiset) divided by the length of |
| 314 | + the longer string. This is *not* edit distance — it is deliberately |
| 315 | + cheap and order-insensitive. |
| 316 | +
|
| 317 | + Returns: |
| 318 | + A float in ``[0.0, 1.0]``. |
| 319 | + """ |
| 320 | + if not a or not b: |
| 321 | + return 0.0 |
| 322 | + # Build character frequency maps |
| 323 | + from collections import Counter |
| 324 | + |
| 325 | + ca = Counter(a.lower()) |
| 326 | + cb = Counter(b.lower()) |
| 327 | + overlap = sum((ca & cb).values()) |
| 328 | + return overlap / max(len(a), len(b)) |
| 329 | + |
| 330 | + |
| 331 | +def _bbox_center(bbox: list[int] | tuple[int, ...]) -> tuple[float, float]: |
| 332 | + """Return the center ``(cx, cy)`` of an ``[x1, y1, x2, y2]`` bbox.""" |
| 333 | + x1, y1, x2, y2 = bbox[:4] |
| 334 | + return ((x1 + x2) / 2.0, (y1 + y2) / 2.0) |
| 335 | + |
| 336 | + |
| 337 | +def _bbox_distance( |
| 338 | + a: list[int] | tuple[int, ...], |
| 339 | + b: list[int] | tuple[int, ...], |
| 340 | +) -> float: |
| 341 | + """Euclidean distance between the centers of two bboxes.""" |
| 342 | + import math |
| 343 | + |
| 344 | + ax, ay = _bbox_center(a) |
| 345 | + bx, by = _bbox_center(b) |
| 346 | + return math.sqrt((ax - bx) ** 2 + (ay - by) ** 2) |
| 347 | + |
| 348 | + |
def run_ocr(screenshot: bytes) -> list[dict]:
    """Run OCR on a screenshot and return detected text regions.

    Delegates to ``pytesseract`` when it is installed.  When the package
    is missing, or OCR itself fails, returns an empty list so callers
    can degrade gracefully — callers must handle ``[]``.

    Args:
        screenshot: PNG image bytes.

    Returns:
        One dict per non-empty detected word, with keys ``"text"``,
        ``"bbox"`` (``[x1, y1, x2, y2]``), and ``"confidence"``
        (``0.0``–``1.0``).
    """
    try:
        import pytesseract  # type: ignore[import-untyped]
    except ImportError:
        logger.debug("pytesseract not installed — returning empty OCR results")
        return []

    try:
        import io

        from PIL import Image

        image = Image.open(io.BytesIO(screenshot))
        data = pytesseract.image_to_data(image, output_type=pytesseract.Output.DICT)
    except Exception as exc:
        logger.warning("OCR failed: %s", exc)
        return []

    regions: list[dict] = []
    for idx, raw in enumerate(data.get("text", [])):
        word = raw.strip()
        if not word:
            continue
        confidence = float(data["conf"][idx])
        # Tesseract reports negative confidence for non-word (layout) boxes.
        if confidence < 0:
            continue
        left = int(data["left"][idx])
        top = int(data["top"][idx])
        width = int(data["width"][idx])
        height = int(data["height"][idx])
        regions.append(
            {
                "text": word,
                "bbox": [left, top, left + width, top + height],
                "confidence": confidence / 100.0,
            }
        )
    return regions
| 398 | + |
def ground_by_text(
    screenshot: bytes,
    target: GroundingTarget,
    ocr_results: list[dict] | None = None,
) -> list[GroundingCandidate]:
    """Ground a target by matching its description against OCR text.

    This is **Tier 1.5a** of the grounding cascade: cheaper and faster
    than a VLM call, but only useful when the target element carries
    readable text.

    Scoring tiers, highest first:

    - **Exact match** (``0.95``): OCR text equals the description.
    - **Case-insensitive match** (``0.90``): equal after lowercasing.
    - **Substring match** (``0.70``): either string contains the other,
      case-insensitive.
    - **Fuzzy match** (``0.50``): character-overlap ratio above 80%.

    A candidate whose bbox center lies within 300px of any OCR region
    matching ``target.nearby_text`` gets a one-time ``+0.05`` boost,
    capped at ``1.0``.

    Args:
        screenshot: PNG image bytes; only OCR'd when *ocr_results* is
            not supplied.
        target: :class:`GroundingTarget`; must have a ``description``.
        ocr_results: Pre-computed OCR output. ``None`` triggers a
            :func:`run_ocr` call on *screenshot*.

    Returns:
        At most 5 :class:`GroundingCandidate` objects, best score first.
        Empty list when nothing matches.
    """
    if not target.description:
        return []

    if ocr_results is None:
        ocr_results = run_ocr(screenshot)
    if not ocr_results:
        return []

    query = target.description
    query_lower = query.lower()

    matched: list[GroundingCandidate] = []
    for region in ocr_results:
        text = region.get("text", "")
        bbox = region.get("bbox")
        if not text or not bbox:
            continue

        text_lower = text.lower()
        if text == query:
            score = 0.95
        elif text_lower == query_lower:
            score = 0.90
        elif query_lower in text_lower or text_lower in query_lower:
            score = 0.70
        elif _char_overlap_ratio(query, text) > 0.80:
            score = 0.50
        else:
            continue  # no tier matched this region

        cx, cy = _bbox_center(bbox)
        matched.append(
            GroundingCandidate(
                source="ocr",
                point=(int(cx), int(cy)),
                bbox=tuple(bbox[:4]),  # type: ignore[arg-type]
                local_score=score,
                matched_text=text,
                reasoning=f"OCR text match: {text!r} (score={score:.2f})",
            )
        )

    # Proximity boost: +0.05 (once per candidate) near any nearby_text hit.
    if target.nearby_text and matched:
        anchor_boxes: list[list[int]] = [
            region["bbox"]
            for hint in target.nearby_text
            for region in ocr_results
            if hint.lower() in region.get("text", "").lower()
            and region.get("bbox")
        ]

        if anchor_boxes:
            threshold = 300.0  # pixels
            for cand in matched:
                if cand.bbox is None:
                    continue
                near_anchor = any(
                    _bbox_distance(list(cand.bbox), anchor) < threshold
                    for anchor in anchor_boxes
                )
                if near_anchor:
                    cand.local_score = min(1.0, cand.local_score + 0.05)
                    cand.reasoning = (
                        f"{cand.reasoning}, "
                        f"nearby boost (+0.05)"
                    )

    # Stable sort by score (descending); keep the top 5.
    return sorted(matched, key=lambda c: c.local_score, reverse=True)[:5]
0 commit comments