11from abc import ABC , abstractmethod
22from typing import Literal , TypeVar , Generic
3+ from typing_extensions import Self
4+ from dataclasses import dataclass
35
46
# Type variable for the output type a LocatorSerializer produces.
SerializedLocator = TypeVar("SerializedLocator")


# Allowed values for the `reference_point` of a neighbor relation.
ReferencePoint = Literal["center", "boundary", "any"]
11+
12+
13+ @dataclass (kw_only = True )
14+ class RelationBase (ABC ):
15+ other_locator : "Locator"
16+
17+ def __str__ (self ):
18+ return f"{ self .type } { self .other_locator } "
19+
20+
@dataclass(kw_only=True)
class NeighborRelation(RelationBase):
    """Directional relation (above/below/right/left) to another locator."""

    # Direction of the relation.
    type: Literal["above_of", "below_of", "right_of", "left_of"]
    # Index stored with the relation; presumably selects the n-th match in
    # the given direction -- TODO confirm the exact semantics.
    index: int
    # One of the ReferencePoint values ("center", "boundary", "any").
    reference_point: ReferencePoint

    def __str__(self) -> str:
        """Return the base description extended by index and reference point."""
        return (
            f"{self.type} {self.other_locator} at index {self.index} "
            f"in reference to {self.reference_point}"
        )
29+
30+
@dataclass(kw_only=True)
class LogicalRelation(RelationBase):
    """Logical combination (``"and"`` / ``"or"``) with another locator."""

    # Which logical operator joins the two locators.
    type: Literal["and", "or"]
34+
35+
@dataclass(kw_only=True)
class BoundingRelation(RelationBase):
    """Containment relation (``"containing"`` / ``"inside_of"``) to another locator."""

    # Direction of the containment.
    type: Literal["containing", "inside_of"]
39+
40+
@dataclass(kw_only=True)
class NearestToRelation(RelationBase):
    """Proximity relation (``"nearest_to"``) to another locator."""

    # Single allowed value; kept as a field so every relation exposes `type`.
    type: Literal["nearest_to"]
44+
45+
# Union of every concrete relation type a locator can carry.
Relation = NeighborRelation | LogicalRelation | BoundingRelation | NearestToRelation
647
748
849class LocatorSerializer (Generic [SerializedLocator ], ABC ):
@@ -11,8 +52,124 @@ def serialize(self, locator: "Locator") -> SerializedLocator:
1152 raise NotImplementedError ()
1253
1354
14- class Locator :
15- def serialize (self , serializer : LocatorSerializer [SerializedLocator ]) -> SerializedLocator :
class Relatable(ABC):
    """Base for objects that can be put in relation to other locators.

    Every public method appends exactly one relation to ``self.relations``
    and returns ``self`` so that calls can be chained.
    """

    def __init__(self) -> None:
        # Relations in the order in which they were added.
        self.relations: list[Relation] = []

    def _add_relation(self, relation: Relation) -> Self:
        """Append *relation* to ``self.relations`` and return ``self``."""
        self.relations.append(relation)
        return self

    def _add_neighbor_relation(
        self,
        relation_type: Literal["above_of", "below_of", "right_of", "left_of"],
        other_locator: "Locator",
        index: int,
        reference_point: ReferencePoint,
    ) -> Self:
        """Append a ``NeighborRelation`` of *relation_type* and return ``self``."""
        return self._add_relation(
            NeighborRelation(
                type=relation_type,
                other_locator=other_locator,
                index=index,
                reference_point=reference_point,
            )
        )

    def above_of(
        self,
        other_locator: "Locator",
        index: int = 0,
        reference_point: ReferencePoint = "boundary",
    ) -> Self:
        """Add an ``"above_of"`` relation to *other_locator*.

        Args:
            other_locator: The locator this one is related to.
            index: Index stored on the relation (default ``0``).
            reference_point: ``"center"``, ``"boundary"`` or ``"any"``
                (default ``"boundary"``).

        Returns:
            ``self``, to allow chaining.
        """
        return self._add_neighbor_relation(
            "above_of", other_locator, index, reference_point
        )

    def below_of(
        self,
        other_locator: "Locator",
        index: int = 0,
        reference_point: ReferencePoint = "boundary",
    ) -> Self:
        """Add a ``"below_of"`` relation to *other_locator*.

        Args:
            other_locator: The locator this one is related to.
            index: Index stored on the relation (default ``0``).
            reference_point: ``"center"``, ``"boundary"`` or ``"any"``
                (default ``"boundary"``).

        Returns:
            ``self``, to allow chaining.
        """
        return self._add_neighbor_relation(
            "below_of", other_locator, index, reference_point
        )

    def right_of(
        self,
        other_locator: "Locator",
        index: int = 0,
        reference_point: ReferencePoint = "boundary",
    ) -> Self:
        """Add a ``"right_of"`` relation to *other_locator*.

        Args:
            other_locator: The locator this one is related to.
            index: Index stored on the relation (default ``0``).
            reference_point: ``"center"``, ``"boundary"`` or ``"any"``
                (default ``"boundary"``).

        Returns:
            ``self``, to allow chaining.
        """
        return self._add_neighbor_relation(
            "right_of", other_locator, index, reference_point
        )

    def left_of(
        self,
        other_locator: "Locator",
        index: int = 0,
        reference_point: ReferencePoint = "boundary",
    ) -> Self:
        """Add a ``"left_of"`` relation to *other_locator*.

        Args:
            other_locator: The locator this one is related to.
            index: Index stored on the relation (default ``0``).
            reference_point: ``"center"``, ``"boundary"`` or ``"any"``
                (default ``"boundary"``).

        Returns:
            ``self``, to allow chaining.
        """
        return self._add_neighbor_relation(
            "left_of", other_locator, index, reference_point
        )

    def containing(self, other_locator: "Locator") -> Self:
        """Add a ``"containing"`` relation to *other_locator*; return ``self``."""
        return self._add_relation(
            BoundingRelation(type="containing", other_locator=other_locator)
        )

    def inside_of(self, other_locator: "Locator") -> Self:
        """Add an ``"inside_of"`` relation to *other_locator*; return ``self``."""
        return self._add_relation(
            BoundingRelation(type="inside_of", other_locator=other_locator)
        )

    def nearest_to(self, other_locator: "Locator") -> Self:
        """Add a ``"nearest_to"`` relation to *other_locator*; return ``self``."""
        return self._add_relation(
            NearestToRelation(type="nearest_to", other_locator=other_locator)
        )

    def and_(self, other_locator: "Locator") -> Self:
        """Add a logical ``"and"`` relation to *other_locator*; return ``self``."""
        return self._add_relation(
            LogicalRelation(type="and", other_locator=other_locator)
        )

    def or_(self, other_locator: "Locator") -> Self:
        """Add a logical ``"or"`` relation to *other_locator*; return ``self``."""
        return self._add_relation(
            LogicalRelation(type="or", other_locator=other_locator)
        )
167+
168+
class Locator(Relatable, ABC):
    """Base class of all locators; inherits the relation builders of ``Relatable``."""

    def serialize(
        self, serializer: LocatorSerializer[SerializedLocator]
    ) -> SerializedLocator:
        """Serialize this locator by delegating to *serializer*.

        Args:
            serializer: The serializer whose ``serialize`` method is called
                with this locator.

        Returns:
            Whatever ``serializer.serialize(self)`` returns.
        """
        return serializer.serialize(self)
17174
18175
@@ -30,7 +187,11 @@ def __init__(self, class_name: Literal["text", "textfield"] | None = None):
30187 self .class_name = class_name
31188
32189 def __str__ (self ):
33- return f'element with class "{ self .class_name } "' if self .class_name else "element that has a class"
190+ return (
191+ f'element with class "{ self .class_name } "'
192+ if self .class_name
193+ else "element that has a class"
194+ )
34195
35196
36197class Text (Class ):
@@ -61,38 +222,84 @@ def __str__(self):
61222
62223class AskUiLocatorSerializer (LocatorSerializer [str ]):
63224 _TEXT_DELIMITER = "<|string|>"
64-
225+ _RP_TO_INTERSECTION_AREA_MAPPING : dict [ReferencePoint , str ] = {
226+ "center" : "element_center_line" ,
227+ "boundary" : "element_edge_area" ,
228+ "any" : "display_edge_area" ,
229+ }
230+ _RELATION_TYPE_MAPPING : dict [str , str ] = {
231+ "above_of" : "above" ,
232+ "below_of" : "below" ,
233+ "right_of" : "right of" ,
234+ "left_of" : "left of" ,
235+ "containing" : "contains" ,
236+ "inside_of" : "inside" ,
237+ "nearest_to" : "nearest to" ,
238+ "and" : "and" ,
239+ "or" : "or" ,
240+ }
241+
65242 def serialize (self , locator : Locator ) -> str :
243+ if len (locator .relations ) > 1 :
244+ raise NotImplementedError (
245+ "Serializing locators with multiple relations is not yet supported by AskUI"
246+ )
247+
66248 prefix = "Click on "
67249 if isinstance (locator , Text ):
68- return prefix + self ._serialize_text (locator )
250+ serialized = prefix + self ._serialize_text (locator )
69251 elif isinstance (locator , Class ):
70- return prefix + self ._serialize_class (locator )
252+ serialized = prefix + self ._serialize_class (locator )
71253 elif isinstance (locator , Description ):
72- return prefix + self ._serialize_description (locator )
254+ serialized = prefix + self ._serialize_description (locator )
73255 else :
74256 raise ValueError (f"Unsupported locator type: { type (locator )} " )
75257
258+ if len (locator .relations ) == 0 :
259+ return serialized
260+
261+ return serialized + " " + self ._serialize_relation (locator .relations [0 ])
262+
76263 def _serialize_class (self , class_ : Class ) -> str :
77264 return class_ .class_name or "element"
78-
265+
79266 def _serialize_description (self , description : Description ) -> str :
80- return f'pta { self ._TEXT_DELIMITER } { description .description } { self ._TEXT_DELIMITER } '
267+ return (
268+ f"pta { self ._TEXT_DELIMITER } { description .description } { self ._TEXT_DELIMITER } "
269+ )
81270
82271 def _serialize_text (self , text : Text ) -> str :
83272 match text .match_type :
84273 case "similar" :
85- return f' with text { self ._TEXT_DELIMITER } { text .text } { self ._TEXT_DELIMITER } that matches to { text .similarity_threshold } %'
274+ return f" with text { self ._TEXT_DELIMITER } { text .text } { self ._TEXT_DELIMITER } that matches to { text .similarity_threshold } %"
86275 case "exact" :
87- return f' equals text { self ._TEXT_DELIMITER } { text .text } { self ._TEXT_DELIMITER } '
276+ return f" equals text { self ._TEXT_DELIMITER } { text .text } { self ._TEXT_DELIMITER } "
88277 case "contains" :
89- return f' contain text { self ._TEXT_DELIMITER } { text .text } { self ._TEXT_DELIMITER } '
278+ return f" contain text { self ._TEXT_DELIMITER } { text .text } { self ._TEXT_DELIMITER } "
90279 case "regex" :
91- return f'match regex pattern { self ._TEXT_DELIMITER } { text .text } { self ._TEXT_DELIMITER } '
280+ return f"match regex pattern { self ._TEXT_DELIMITER } { text .text } { self ._TEXT_DELIMITER } "
281+
282+ def _serialize_relation (self , relation : Relation ) -> str :
283+ match relation .type :
284+ case "above_of" | "below_of" | "right_of" | "left_of" :
285+ assert isinstance (relation , NeighborRelation )
286+ return self ._serialize_neighbor_relation (relation )
287+ case "containing" | "inside_of" | "nearest_to" | "and" | "or" :
288+ return f"{ self ._RELATION_TYPE_MAPPING [relation .type ]} { self .serialize (relation .other_locator )} "
289+ case _:
290+ raise ValueError (f"Unsupported relation type: { relation .type } " )
291+
292+ def _serialize_neighbor_relation (self , relation : NeighborRelation ) -> str :
293+ return f"index { relation .index } { self ._RELATION_TYPE_MAPPING [relation .type ]} intersection_area { self ._RP_TO_INTERSECTION_AREA_MAPPING [relation .reference_point ]} { self .serialize (relation .other_locator )} "
92294
93295
94296class VlmLocatorSerializer (LocatorSerializer [str ]):
95297 def serialize (self , locator : Locator ) -> str :
298+ if len (locator .relations ) > 0 :
299+ raise NotImplementedError (
300+ "Serializing locators with relations is not yet supported for VLMs"
301+ )
302+
96303 if isinstance (locator , Text ):
97304 return self ._serialize_text (locator )
98305 elif isinstance (locator , Class ):
@@ -107,7 +314,7 @@ def _serialize_class(self, class_: Class) -> str:
107314 return f"an arbitrary { class_ .class_name } shown"
108315 else :
109316 return "an arbitrary ui element (e.g., text, button, textfield, etc.)"
110-
317+
111318 def _serialize_description (self , description : Description ) -> str :
112319 return description .description
113320
0 commit comments