11from agentlab .llm .llm_utils import ParseError , retry
22from agentlab .llm .tracking import cost_tracker_decorator
3+ from browsergym .core .action .highlevel import HighLevelActionSet
34from browsergym .experiments .agent import AgentInfo
45from browsergym .experiments .benchmark import Benchmark
56from browsergym .experiments .benchmark .base import HighLevelActionSetArgs
67from browsergym .utils .obs import overlay_som
78from copy import copy , deepcopy
89from dataclasses import asdict , dataclass
10+ from functools import cache
911from typing import Optional
1012from .base import VLAgent , VLAgentArgs
1113from ..vl_model .base import VLModelArgs
@@ -17,21 +19,26 @@ def __init__(
1719 self ,
1820 main_vl_model_args : VLModelArgs ,
1921 auxiliary_vl_model_args : Optional [VLModelArgs ],
20- action_set_args : HighLevelActionSetArgs ,
2122 ui_prompt_args : UIPromptArgs ,
23+ action_set_args : HighLevelActionSetArgs ,
2224 max_retry : int ,
2325 ):
2426 self .main_vl_model = main_vl_model_args .make_model ()
2527 if auxiliary_vl_model_args is None :
2628 self .auxiliary_vl_model = None
2729 else :
2830 self .auxiliary_vl_model = auxiliary_vl_model_args .make_model ()
29- self .action_set = action_set_args .make_action_set ()
3031 self .ui_prompt_args = ui_prompt_args
32+ self .action_set_args = action_set_args
3133 self .max_retry = max_retry
3234 self .thoughts = []
3335 self .actions = []
3436
37+ @property
38+ @cache
39+ def action_set (self ) -> HighLevelActionSet :
40+ return self .action_set_args .make_action_set ()
41+
3542 @cost_tracker_decorator
3643 def get_action (self , obs : dict ) -> tuple [str , dict ]:
3744 ui_prompt = self .ui_prompt_args .make_prompt (
@@ -94,25 +101,27 @@ def obs_preprocessor(self, obs: dict) -> dict:
94101class UIAgentArgs (VLAgentArgs ):
95102 main_vl_model_args : VLModelArgs
96103 auxiliary_vl_model_args : Optional [VLModelArgs ]
97- action_set_args : HighLevelActionSetArgs
98104 ui_prompt_args : UIPromptArgs
105+ action_set_args : HighLevelActionSetArgs
99106 max_retry : int
100107
101108 @property
109+ @cache
102110 def agent_name (self ) -> str :
103111 if self .auxiliary_vl_model_args is None :
104112 return f"UIAgent-{ self .main_vl_model_args .model_name } "
105113 else :
106114 return f"UIAgent-{ self .main_vl_model_args .model_name } -{ self .auxiliary_vl_model_args .model_name } "
107115
108116 def make_agent (self ) -> UIAgent :
109- return UIAgent (
117+ self . ui_agent = UIAgent (
110118 main_vl_model_args = self .main_vl_model_args ,
111119 auxiliary_vl_model_args = self .auxiliary_vl_model_args ,
112- action_set_args = self .action_set_args ,
113120 ui_prompt_args = self .ui_prompt_args ,
121+ action_set_args = self .action_set_args ,
114122 max_retry = self .max_retry ,
115123 )
124+ return self .ui_agent
116125
117126 def prepare (self ):
118127 self .main_vl_model_args .prepare ()
0 commit comments