Skip to content

Commit 2b252c6

Browse files
committed
feat: add description, class and text locators
1 parent b6e41c1 commit 2b252c6

14 files changed

Lines changed: 297 additions & 67 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ distribution = true
4040

4141
[tool.pdm.scripts]
4242
test = "pytest"
43-
"test:unit" = "pytest tests/unit"
43+
"test:e2e" = "pytest tests/e2e"
4444
"test:integration" = "pytest tests/integration"
45+
"test:unit" = "pytest tests/unit"
4546
sort = "isort ."
4647
format = "black ."
4748
lint = "ruff check ."

src/askui/agent.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pydantic import Field, validate_call
66

77
from askui.container import telemetry
8+
from askui.models.locators import Locator
89

910
from .tools.askui.askui_controller import (
1011
AskUiControllerClient,
@@ -15,7 +16,7 @@
1516
from .models.anthropic.claude import ClaudeHandler
1617
from .logger import logger, configure_logging
1718
from .tools.toolbox import AgentToolbox
18-
from .models.router import ModelRouter
19+
from .models.router import ModelRouter, Point
1920
from .reporting.report import SimpleReportGenerator
2021
import time
2122
from dotenv import load_dotenv
@@ -59,13 +60,13 @@ def _check_askui_controller_enabled(self) -> None:
5960
"AskUI Controller is not initialized. Please, set `enable_askui_controller` to `True` when initializing the `VisionAgent`."
6061
)
6162

62-
@telemetry.record_call(exclude={"instruction"})
63-
def click(self, instruction: Optional[str] = None, button: Literal['left', 'middle', 'right'] = 'left', repeat: int = 1, model_name: Optional[str] = None) -> None:
63+
@telemetry.record_call(exclude={"locator"})
64+
def click(self, locator: Optional[str | Locator] = None, button: Literal['left', 'middle', 'right'] = 'left', repeat: int = 1, model_name: Optional[str] = None) -> None:
6465
"""
65-
Simulates a mouse click on the user interface element identified by the provided instruction.
66+
Simulates a mouse click on the user interface element identified by the provided locator.
6667
6768
Parameters:
68-
instruction (str | None): The identifier or description of the element to click.
69+
locator (str | Locator | None): The identifier or description of the element to click.
6970
button ('left' | 'middle' | 'right'): Specifies which mouse button to click. Defaults to 'left'.
7071
repeat (int): The number of times to click. Must be greater than 0. Defaults to 1.
7172
model_name (str | None): The model name to be used for element detection. Optional.
@@ -92,29 +93,34 @@ def click(self, instruction: Optional[str] = None, button: Literal['left', 'midd
9293
msg = f'{button} ' + msg
9394
if repeat > 1:
9495
msg += f' {repeat}x times'
95-
if instruction is not None:
96-
msg += f' on "{instruction}"'
96+
if locator is not None:
97+
msg += f' on "{locator}"'
9798
self.report.add_message("User", msg)
98-
if instruction is not None:
99-
logger.debug("VisionAgent received instruction to click '%s'", instruction)
100-
self.__mouse_move(instruction, model_name)
99+
if locator is not None:
100+
logger.debug("VisionAgent received instruction to click '%s'", locator)
101+
self._mouse_move(locator, model_name)
101102
self.client.click(button, repeat) # type: ignore
102-
103-
def __mouse_move(self, instruction: str, model_name: Optional[str] = None) -> None:
104-
self._check_askui_controller_enabled()
105-
screenshot = self.client.screenshot() # type: ignore
106-
x, y = self.model_router.locate(screenshot, instruction, model_name)
103+
104+
def locate(self, locator: str | Locator, screenshot: Optional[Image.Image] = None, model_name: Optional[str] = None) -> Point:
105+
if screenshot is None:
106+
self._check_askui_controller_enabled()
107+
screenshot = self.client.screenshot() # type: ignore
108+
point = self.model_router.locate(screenshot, locator, model_name)
107109
if self.report is not None:
108-
self.report.add_message("ModelRouter", f"locate: ({x}, {y})")
109-
self.client.mouse(x, y) # type: ignore
110+
self.report.add_message("ModelRouter", f"locate: ({point[0]}, {point[1]})")
111+
return point
112+
113+
def _mouse_move(self, locator: str | Locator, model_name: Optional[str] = None) -> None:
114+
point = self.locate(locator=locator, model_name=model_name)
115+
self.client.mouse(point[0], point[1]) # type: ignore
110116

111-
@telemetry.record_call(exclude={"instruction"})
112-
def mouse_move(self, instruction: str, model_name: Optional[str] = None) -> None:
117+
@telemetry.record_call(exclude={"locator"})
118+
def mouse_move(self, locator: str | Locator, model_name: Optional[str] = None) -> None:
113119
"""
114-
Moves the mouse cursor to the UI element identified by the provided instruction.
120+
Moves the mouse cursor to the UI element identified by the provided locator.
115121
116122
Parameters:
117-
instruction (str): The identifier or description of the element to move to.
123+
locator (str | Locator): The identifier or description of the element to move to.
118124
model_name (str | None): The model name to be used for element detection. Optional.
119125
120126
Example:
@@ -126,9 +132,9 @@ def mouse_move(self, instruction: str, model_name: Optional[str] = None) -> None
126132
```
127133
"""
128134
if self.report is not None:
129-
self.report.add_message("User", f'mouse_move: "{instruction}"')
130-
logger.debug("VisionAgent received instruction to mouse_move '%s'", instruction)
131-
self.__mouse_move(instruction, model_name)
135+
self.report.add_message("User", f'mouse_move: "{locator}"')
136+
logger.debug("VisionAgent received instruction to mouse_move to '%s'", locator)
137+
self._mouse_move(locator, model_name)
132138

133139
@telemetry.record_call()
134140
def mouse_scroll(self, x: int, y: int) -> None:

src/askui/chat/__main__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def rerun():
211211
image=screenshot_with_crosshair,
212212
)
213213
agent.mouse_move(
214-
instruction=element_description.replace('"', ""),
214+
locator=element_description.replace('"', ""),
215215
model_name="anthropic-claude-3-5-sonnet-20241022",
216216
)
217217
else:

src/askui/models/__init__.py

Whitespace-only changes.

src/askui/models/askui/ai_element_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def __init__(self, additional_ai_element_locations: Optional[List[pathlib.Path]]
8787

8888
logger.debug("AI Element locations: %s", self.ai_element_locations)
8989

90-
def find(self, name: str):
91-
ai_elements = []
90+
def find(self, name: str) -> list[AiElement]:
91+
ai_elements: list[AiElement] = []
9292

9393
for location in self.ai_element_locations:
9494
path = pathlib.Path(location)
@@ -105,4 +105,4 @@ def find(self, name: str):
105105
if ai_element.metadata.name == name:
106106
ai_elements.append(ai_element)
107107

108-
return ai_elements
108+
return ai_elements

src/askui/models/askui/api.py

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import requests
55

66
from PIL import Image
7-
from typing import List, Union
7+
from typing import Any, List, Union
88
from askui.models.askui.ai_element_utils import AiElement, AiElementCollection, AiElementNotFound
9+
from askui.models.locators import AskUiLocatorSerializer, Locator
910
from askui.utils import image_to_base64
1011
from askui.logger import logger
1112

@@ -23,6 +24,7 @@ def __init__(self):
2324
self.authenticated = False
2425

2526
self.ai_element_collection = AiElementCollection()
27+
self._locator_serializer = AskUiLocatorSerializer()
2628

2729

2830

@@ -32,7 +34,7 @@ def _build_askui_token_auth_header(self, bearer_token: str | None = None) -> dic
3234
token_base64 = base64.b64encode(self.token.encode("utf-8")).decode("utf-8")
3335
return {"Authorization": f"Basic {token_base64}"}
3436

35-
def _build_custom_elements(self, ai_elements: List[AiElement] | None):
37+
def _build_custom_elements(self, ai_elements: List[AiElement] | None) -> list[dict[str, str]]:
3638
"""
3739
Converts AiElements to the CustomElementDto format expected by the backend.
3840
@@ -43,9 +45,9 @@ def _build_custom_elements(self, ai_elements: List[AiElement] | None):
4345
dict: Custom elements in the format expected by the backend
4446
"""
4547
if not ai_elements:
46-
return {}
48+
return []
4749

48-
custom_elements = []
50+
custom_elements: list[dict[str, str]] = []
4951
for element in ai_elements:
5052
custom_element = {
5153
"customImage": "," + image_to_base64(element.image),
@@ -54,24 +56,22 @@ def _build_custom_elements(self, ai_elements: List[AiElement] | None):
5456
}
5557
custom_elements.append(custom_element)
5658

57-
return {
58-
"customElements": custom_elements
59-
}
60-
def __build_model_composition(self):
61-
return {}
59+
return custom_elements
6260

6361
def __build_base_url(self, endpoint: str = "inference") -> str:
6462
return f"{self.inference_endpoint}/api/v3/workspaces/{self.workspace_id}/{endpoint}"
6563

66-
def predict(self, image: Union[pathlib.Path, Image.Image], locator: str, ai_elements: List[pathlib.Path] = None) -> tuple[int | None, int | None]:
64+
def predict(self, image: Union[pathlib.Path, Image.Image], locator: str | Locator, ai_elements: List[AiElement] | None = None) -> tuple[int | None, int | None]:
65+
json: dict[str, Any] = {
66+
"image": f",{image_to_base64(image)}",
67+
}
68+
if locator is not None:
69+
json["instruction"] = locator if isinstance(locator, str) else locator.serialize(serializer=self._locator_serializer)
70+
if ai_elements is not None:
71+
json["customElements"] = self._build_custom_elements(ai_elements)
6772
response = requests.post(
6873
self.__build_base_url(),
69-
json={
70-
"image": f",{image_to_base64(image)}",
71-
**({"instruction": locator} if locator is not None else {}),
72-
**self.__build_model_composition(),
73-
**self._build_custom_elements(ai_elements)
74-
},
74+
json=json,
7575
headers={"Content-Type": "application/json", **self._build_askui_token_auth_header()},
7676
timeout=30,
7777
)
@@ -83,23 +83,23 @@ def predict(self, image: Union[pathlib.Path, Image.Image], locator: str, ai_elem
8383
actions = [el for el in content["data"]["actions"] if el["inputEvent"] == "MOUSE_MOVE"]
8484
if len(actions) == 0:
8585
return None, None
86-
position = actions[0]["position"]
8786

87+
position = actions[0]["position"]
8888
return int(position["x"]), int(position["y"])
8989

90-
def locate_pta_prediction(self, image: Union[pathlib.Path, Image.Image], locator: str) -> tuple[int | None, int | None]:
91-
askui_locator = f'Click on pta "{locator}"'
92-
return self.predict(image, askui_locator)
90+
def locate_pta_prediction(self, image: Union[pathlib.Path, Image.Image], locator: str | Locator) -> tuple[int | None, int | None]:
91+
_locator = f'Click on pta "{locator}"' if isinstance(locator, str) else locator
92+
return self.predict(image, _locator)
9393

94-
def locate_ocr_prediction(self, image: Union[pathlib.Path, Image.Image], locator: str) -> tuple[int | None, int | None]:
95-
askui_locator = f'Click on with text "{locator}"'
96-
return self.predict(image, askui_locator)
94+
def locate_ocr_prediction(self, image: Union[pathlib.Path, Image.Image], locator: str | Locator) -> tuple[int | None, int | None]:
95+
_locator = f'Click on with text "{locator}"' if isinstance(locator, str) else locator
96+
return self.predict(image, _locator)
9797

9898
def locate_ai_element_prediction(self, image: Union[pathlib.Path, Image.Image], name: str) -> tuple[int | None, int | None]:
9999
ai_elements = self.ai_element_collection.find(name)
100100

101101
if len(ai_elements) == 0:
102102
raise AiElementNotFound(f"Could not locate AI element with name '{name}'")
103103

104-
askui_instruction = f'Click on custom element with text "{name}"'
105-
return self.predict(image, askui_instruction, ai_elements=ai_elements)
104+
_locator = f'Click on custom element with text "{name}"'
105+
return self.predict(image, _locator, ai_elements=ai_elements)

src/askui/models/locators.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Literal, TypeVar, Generic
3+
4+
5+
SerializedLocator = TypeVar('SerializedLocator')
6+
7+
8+
class LocatorSerializer(Generic[SerializedLocator], ABC):
9+
@abstractmethod
10+
def serialize(self, locator: "Locator") -> SerializedLocator:
11+
raise NotImplementedError()
12+
13+
14+
class Locator:
15+
def serialize(self, serializer: LocatorSerializer[SerializedLocator]) -> SerializedLocator:
16+
return serializer.serialize(self)
17+
18+
19+
class Description(Locator):
20+
def __init__(self, description: str):
21+
self.description = description
22+
23+
def __str__(self):
24+
return f'element with description "{self.description}"'
25+
26+
27+
class Class(Locator):
28+
# None is used to indicate that it is an element with a class but not a specific class
29+
def __init__(self, class_name: Literal["text", "textfield"] | None = None):
30+
self.class_name = class_name
31+
32+
def __str__(self):
33+
return f'element with class "{self.class_name}"' if self.class_name else "element that has a class"
34+
35+
36+
class Text(Class):
37+
def __init__(
38+
self,
39+
text: str | None = None,
40+
match_type: Literal["similar", "exact", "contains", "regex"] = "similar",
41+
similarity_threshold: int = 70,
42+
):
43+
super().__init__(class_name="text")
44+
self.text = text
45+
self.match_type = match_type
46+
self.similarity_threshold = similarity_threshold
47+
48+
def __str__(self):
49+
result = "text "
50+
match self.match_type:
51+
case "similar":
52+
result += f'similar to "{self.text}" (similarity >= {self.similarity_threshold}%)'
53+
case "exact":
54+
result += f'"{self.text}"'
55+
case "contains":
56+
result += f'containing text "{self.text}"'
57+
case "regex":
58+
result += f'matching regex "{self.text}"'
59+
return result
60+
61+
62+
class AskUiLocatorSerializer(LocatorSerializer[str]):
63+
_TEXT_DELIMITER = "<|string|>"
64+
65+
def serialize(self, locator: Locator) -> str:
66+
prefix = "Click on "
67+
if isinstance(locator, Text):
68+
return prefix + self._serialize_text(locator)
69+
elif isinstance(locator, Class):
70+
return prefix + self._serialize_class(locator)
71+
elif isinstance(locator, Description):
72+
return prefix + self._serialize_description(locator)
73+
else:
74+
raise ValueError(f"Unsupported locator type: {type(locator)}")
75+
76+
def _serialize_class(self, class_: Class) -> str:
77+
return class_.class_name or "element"
78+
79+
def _serialize_description(self, description: Description) -> str:
80+
return f'pta {self._TEXT_DELIMITER}{description.description}{self._TEXT_DELIMITER}'
81+
82+
def _serialize_text(self, text: Text) -> str:
83+
match text.match_type:
84+
case "similar":
85+
return f'with text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER} that matches to {text.similarity_threshold} %'
86+
case "exact":
87+
return f'equals text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}'
88+
case "contains":
89+
return f'contain text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}'
90+
case "regex":
91+
return f'match regex pattern {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}'
92+
93+
94+
class VlmLocatorSerializer(LocatorSerializer[str]):
95+
def serialize(self, locator: Locator) -> str:
96+
if isinstance(locator, Text):
97+
return self._serialize_text(locator)
98+
elif isinstance(locator, Class):
99+
return self._serialize_class(locator)
100+
elif isinstance(locator, Description):
101+
return self._serialize_description(locator)
102+
else:
103+
raise ValueError(f"Unsupported locator type: {type(locator)}")
104+
105+
def _serialize_class(self, class_: Class) -> str:
106+
return class_.class_name or "ui element"
107+
108+
def _serialize_description(self, description: Description) -> str:
109+
return description.description
110+
111+
def _serialize_text(self, text: Text) -> str:
112+
if text.match_type == "similar":
113+
return f'text similar to "{text.text}"'
114+
115+
return str(text)

0 commit comments

Comments
 (0)