44import requests
55
66from PIL import Image
7- from typing import List , Union
7+ from typing import Any , List , Union
88from askui .models .askui .ai_element_utils import AiElement , AiElementCollection , AiElementNotFound
9+ from askui .models .locators import AskUiLocatorSerializer , Locator
910from askui .utils import image_to_base64
1011from askui .logger import logger
1112
@@ -23,6 +24,7 @@ def __init__(self):
2324 self .authenticated = False
2425
2526 self .ai_element_collection = AiElementCollection ()
27+ self ._locator_serializer = AskUiLocatorSerializer ()
2628
2729
2830
@@ -32,7 +34,7 @@ def _build_askui_token_auth_header(self, bearer_token: str | None = None) -> dic
3234 token_base64 = base64 .b64encode (self .token .encode ("utf-8" )).decode ("utf-8" )
3335 return {"Authorization" : f"Basic { token_base64 } " }
3436
35- def _build_custom_elements (self , ai_elements : List [AiElement ] | None ):
37+ def _build_custom_elements (self , ai_elements : List [AiElement ] | None ) -> list [ dict [ str , str ]] :
3638 """
3739 Converts AiElements to the CustomElementDto format expected by the backend.
3840
@@ -43,9 +45,9 @@ def _build_custom_elements(self, ai_elements: List[AiElement] | None):
4345 dict: Custom elements in the format expected by the backend
4446 """
4547 if not ai_elements :
46- return {}
48+ return []
4749
48- custom_elements = []
50+ custom_elements : list [ dict [ str , str ]] = []
4951 for element in ai_elements :
5052 custom_element = {
5153 "customImage" : "," + image_to_base64 (element .image ),
@@ -54,24 +56,22 @@ def _build_custom_elements(self, ai_elements: List[AiElement] | None):
5456 }
5557 custom_elements .append (custom_element )
5658
57- return {
58- "customElements" : custom_elements
59- }
60- def __build_model_composition (self ):
61- return {}
59+ return custom_elements
6260
6361 def __build_base_url (self , endpoint : str = "inference" ) -> str :
6462 return f"{ self .inference_endpoint } /api/v3/workspaces/{ self .workspace_id } /{ endpoint } "
6563
66- def predict (self , image : Union [pathlib .Path , Image .Image ], locator : str , ai_elements : List [pathlib .Path ] = None ) -> tuple [int | None , int | None ]:
64+ def predict (self , image : Union [pathlib .Path , Image .Image ], locator : str | Locator , ai_elements : List [AiElement ] | None = None ) -> tuple [int | None , int | None ]:
65+ json : dict [str , Any ] = {
66+ "image" : f",{ image_to_base64 (image )} " ,
67+ }
68+ if locator is not None :
69+ json ["instruction" ] = locator if isinstance (locator , str ) else locator .serialize (serializer = self ._locator_serializer )
70+ if ai_elements is not None :
71+ json ["customElements" ] = self ._build_custom_elements (ai_elements )
6772 response = requests .post (
6873 self .__build_base_url (),
69- json = {
70- "image" : f",{ image_to_base64 (image )} " ,
71- ** ({"instruction" : locator } if locator is not None else {}),
72- ** self .__build_model_composition (),
73- ** self ._build_custom_elements (ai_elements )
74- },
74+ json = json ,
7575 headers = {"Content-Type" : "application/json" , ** self ._build_askui_token_auth_header ()},
7676 timeout = 30 ,
7777 )
@@ -83,23 +83,23 @@ def predict(self, image: Union[pathlib.Path, Image.Image], locator: str, ai_elem
8383 actions = [el for el in content ["data" ]["actions" ] if el ["inputEvent" ] == "MOUSE_MOVE" ]
8484 if len (actions ) == 0 :
8585 return None , None
86- position = actions [0 ]["position" ]
8786
87+ position = actions [0 ]["position" ]
8888 return int (position ["x" ]), int (position ["y" ])
8989
90- def locate_pta_prediction (self , image : Union [pathlib .Path , Image .Image ], locator : str ) -> tuple [int | None , int | None ]:
91- askui_locator = f'Click on pta "{ locator } "'
92- return self .predict (image , askui_locator )
90+ def locate_pta_prediction (self , image : Union [pathlib .Path , Image .Image ], locator : str | Locator ) -> tuple [int | None , int | None ]:
91+ _locator = f'Click on pta "{ locator } "' if isinstance ( locator , str ) else locator
92+ return self .predict (image , _locator )
9393
94- def locate_ocr_prediction (self , image : Union [pathlib .Path , Image .Image ], locator : str ) -> tuple [int | None , int | None ]:
95- askui_locator = f'Click on with text "{ locator } "'
96- return self .predict (image , askui_locator )
94+ def locate_ocr_prediction (self , image : Union [pathlib .Path , Image .Image ], locator : str | Locator ) -> tuple [int | None , int | None ]:
95+ _locator = f'Click on with text "{ locator } "' if isinstance ( locator , str ) else locator
96+ return self .predict (image , _locator )
9797
9898 def locate_ai_element_prediction (self , image : Union [pathlib .Path , Image .Image ], name : str ) -> tuple [int | None , int | None ]:
9999 ai_elements = self .ai_element_collection .find (name )
100100
101101 if len (ai_elements ) == 0 :
102102 raise AiElementNotFound (f"Could not locate AI element with name '{ name } '" )
103103
104- askui_instruction = f'Click on custom element with text "{ name } "'
105- return self .predict (image , askui_instruction , ai_elements = ai_elements )
104+ _locator = f'Click on custom element with text "{ name } "'
105+ return self .predict (image , _locator , ai_elements = ai_elements )
0 commit comments