Skip to content

Commit da923d7

Browse files
committed
feat(locators): add relations
1 parent e9c5a97 commit da923d7

2 files changed

Lines changed: 337 additions & 17 deletions

File tree

src/askui/models/locators.py

Lines changed: 222 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,49 @@
11
from abc import ABC, abstractmethod
22
from typing import Literal, TypeVar, Generic
3+
from typing_extensions import Self
4+
from dataclasses import dataclass
35

46

5-
SerializedLocator = TypeVar('SerializedLocator')
7+
SerializedLocator = TypeVar("SerializedLocator")
8+
9+
10+
ReferencePoint = Literal["center", "boundary", "any"]
11+
12+
13+
@dataclass(kw_only=True)
14+
class RelationBase(ABC):
15+
other_locator: "Locator"
16+
17+
def __str__(self):
18+
return f"{self.type} {self.other_locator}"
19+
20+
21+
@dataclass(kw_only=True)
22+
class NeighborRelation(RelationBase):
23+
type: Literal["above_of", "below_of", "right_of", "left_of"]
24+
index: int
25+
reference_point: ReferencePoint
26+
27+
def __str__(self):
28+
return f"{self.type} {self.other_locator} at index {self.index} in reference to {self.reference_point}"
29+
30+
31+
@dataclass(kw_only=True)
32+
class LogicalRelation(RelationBase):
33+
type: Literal["and", "or"]
34+
35+
36+
@dataclass(kw_only=True)
37+
class BoundingRelation(RelationBase):
38+
type: Literal["containing", "inside_of"]
39+
40+
41+
@dataclass(kw_only=True)
42+
class NearestToRelation(RelationBase):
43+
type: Literal["nearest_to"]
44+
45+
46+
Relation = NeighborRelation | LogicalRelation | BoundingRelation | NearestToRelation
647

748

849
class LocatorSerializer(Generic[SerializedLocator], ABC):
@@ -11,8 +52,124 @@ def serialize(self, locator: "Locator") -> SerializedLocator:
1152
raise NotImplementedError()
1253

1354

14-
class Locator:
15-
def serialize(self, serializer: LocatorSerializer[SerializedLocator]) -> SerializedLocator:
55+
class Relatable(ABC):
56+
def __init__(self) -> None:
57+
self.relations: list[Relation] = []
58+
59+
def above_of(
60+
self,
61+
other_locator: "Locator",
62+
index: int = 0,
63+
reference_point: Literal["center", "boundary", "any"] = "boundary",
64+
) -> Self:
65+
self.relations.append(
66+
NeighborRelation(
67+
type="above_of",
68+
other_locator=other_locator,
69+
index=index,
70+
reference_point=reference_point,
71+
)
72+
)
73+
return self
74+
75+
def below_of(
76+
self,
77+
other_locator: "Locator",
78+
index: int = 0,
79+
reference_point: Literal["center", "boundary", "any"] = "boundary",
80+
) -> Self:
81+
self.relations.append(
82+
NeighborRelation(
83+
type="below_of",
84+
other_locator=other_locator,
85+
index=index,
86+
reference_point=reference_point,
87+
)
88+
)
89+
return self
90+
91+
def right_of(
92+
self,
93+
other_locator: "Locator",
94+
index: int = 0,
95+
reference_point: Literal["center", "boundary", "any"] = "boundary",
96+
) -> Self:
97+
self.relations.append(
98+
NeighborRelation(
99+
type="right_of",
100+
other_locator=other_locator,
101+
index=index,
102+
reference_point=reference_point,
103+
)
104+
)
105+
return self
106+
107+
def left_of(
108+
self,
109+
other_locator: "Locator",
110+
index: int = 0,
111+
reference_point: Literal["center", "boundary", "any"] = "boundary",
112+
) -> Self:
113+
self.relations.append(
114+
NeighborRelation(
115+
type="left_of",
116+
other_locator=other_locator,
117+
index=index,
118+
reference_point=reference_point,
119+
)
120+
)
121+
return self
122+
123+
def containing(self, other_locator: "Locator") -> Self:
124+
self.relations.append(
125+
BoundingRelation(
126+
type="containing",
127+
other_locator=other_locator,
128+
)
129+
)
130+
return self
131+
132+
def inside_of(self, other_locator: "Locator") -> Self:
133+
self.relations.append(
134+
BoundingRelation(
135+
type="inside_of",
136+
other_locator=other_locator,
137+
)
138+
)
139+
return self
140+
141+
def nearest_to(self, other_locator: "Locator") -> Self:
142+
self.relations.append(
143+
NearestToRelation(
144+
type="nearest_to",
145+
other_locator=other_locator,
146+
)
147+
)
148+
return self
149+
150+
def and_(self, other_locator: "Locator") -> Self:
151+
self.relations.append(
152+
LogicalRelation(
153+
type="and",
154+
other_locator=other_locator,
155+
)
156+
)
157+
return self
158+
159+
def or_(self, other_locator: "Locator") -> Self:
160+
self.relations.append(
161+
LogicalRelation(
162+
type="or",
163+
other_locator=other_locator,
164+
)
165+
)
166+
return self
167+
168+
169+
class Locator(Relatable, ABC):
170+
def serialize(
171+
self, serializer: LocatorSerializer[SerializedLocator]
172+
) -> SerializedLocator:
16173
return serializer.serialize(self)
17174

18175

@@ -30,7 +187,11 @@ def __init__(self, class_name: Literal["text", "textfield"] | None = None):
30187
self.class_name = class_name
31188

32189
def __str__(self):
33-
return f'element with class "{self.class_name}"' if self.class_name else "element that has a class"
190+
return (
191+
f'element with class "{self.class_name}"'
192+
if self.class_name
193+
else "element that has a class"
194+
)
34195

35196

36197
class Text(Class):
@@ -61,38 +222,84 @@ def __str__(self):
61222

62223
class AskUiLocatorSerializer(LocatorSerializer[str]):
63224
_TEXT_DELIMITER = "<|string|>"
64-
225+
_RP_TO_INTERSECTION_AREA_MAPPING: dict[ReferencePoint, str] = {
226+
"center": "element_center_line",
227+
"boundary": "element_edge_area",
228+
"any": "display_edge_area",
229+
}
230+
_RELATION_TYPE_MAPPING: dict[str, str] = {
231+
"above_of": "above",
232+
"below_of": "below",
233+
"right_of": "right of",
234+
"left_of": "left of",
235+
"containing": "contains",
236+
"inside_of": "inside",
237+
"nearest_to": "nearest to",
238+
"and": "and",
239+
"or": "or",
240+
}
241+
65242
def serialize(self, locator: Locator) -> str:
243+
if len(locator.relations) > 1:
244+
raise NotImplementedError(
245+
"Serializing locators with multiple relations is not yet supported by AskUI"
246+
)
247+
66248
prefix = "Click on "
67249
if isinstance(locator, Text):
68-
return prefix + self._serialize_text(locator)
250+
serialized = prefix + self._serialize_text(locator)
69251
elif isinstance(locator, Class):
70-
return prefix + self._serialize_class(locator)
252+
serialized = prefix + self._serialize_class(locator)
71253
elif isinstance(locator, Description):
72-
return prefix + self._serialize_description(locator)
254+
serialized = prefix + self._serialize_description(locator)
73255
else:
74256
raise ValueError(f"Unsupported locator type: {type(locator)}")
75257

258+
if len(locator.relations) == 0:
259+
return serialized
260+
261+
return serialized + " " + self._serialize_relation(locator.relations[0])
262+
76263
def _serialize_class(self, class_: Class) -> str:
77264
return class_.class_name or "element"
78-
265+
79266
def _serialize_description(self, description: Description) -> str:
80-
return f'pta {self._TEXT_DELIMITER}{description.description}{self._TEXT_DELIMITER}'
267+
return (
268+
f"pta {self._TEXT_DELIMITER}{description.description}{self._TEXT_DELIMITER}"
269+
)
81270

82271
def _serialize_text(self, text: Text) -> str:
83272
match text.match_type:
84273
case "similar":
85-
return f'with text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER} that matches to {text.similarity_threshold} %'
274+
return f"with text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER} that matches to {text.similarity_threshold} %"
86275
case "exact":
87-
return f'equals text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}'
276+
return f"equals text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}"
88277
case "contains":
89-
return f'contain text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}'
278+
return f"contain text {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}"
90279
case "regex":
91-
return f'match regex pattern {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}'
280+
return f"match regex pattern {self._TEXT_DELIMITER}{text.text}{self._TEXT_DELIMITER}"
281+
282+
def _serialize_relation(self, relation: Relation) -> str:
283+
match relation.type:
284+
case "above_of" | "below_of" | "right_of" | "left_of":
285+
assert isinstance(relation, NeighborRelation)
286+
return self._serialize_neighbor_relation(relation)
287+
case "containing" | "inside_of" | "nearest_to" | "and" | "or":
288+
return f"{self._RELATION_TYPE_MAPPING[relation.type]} {self.serialize(relation.other_locator)}"
289+
case _:
290+
raise ValueError(f"Unsupported relation type: {relation.type}")
291+
292+
def _serialize_neighbor_relation(self, relation: NeighborRelation) -> str:
293+
return f"index {relation.index} {self._RELATION_TYPE_MAPPING[relation.type]} intersection_area {self._RP_TO_INTERSECTION_AREA_MAPPING[relation.reference_point]} {self.serialize(relation.other_locator)}"
92294

93295

94296
class VlmLocatorSerializer(LocatorSerializer[str]):
95297
def serialize(self, locator: Locator) -> str:
298+
if len(locator.relations) > 0:
299+
raise NotImplementedError(
300+
"Serializing locators with relations is not yet supported for VLMs"
301+
)
302+
96303
if isinstance(locator, Text):
97304
return self._serialize_text(locator)
98305
elif isinstance(locator, Class):
@@ -107,7 +314,7 @@ def _serialize_class(self, class_: Class) -> str:
107314
return f"an arbitrary {class_.class_name} shown"
108315
else:
109316
return "an arbitrary ui element (e.g., text, button, textfield, etc.)"
110-
317+
111318
def _serialize_description(self, description: Description) -> str:
112319
return description.description
113320

0 commit comments

Comments
 (0)