88 the agent, including model arguments and flags for various behaviors.
99"""
1010
11- from copy import deepcopy
1211from dataclasses import asdict , dataclass
13- from warnings import warn
1412
1513import bgym
1614from browsergym .experiments .agent import Agent , AgentInfo
2119from agentlab .llm .llm_utils import Discussion , ParseError , SystemMessage , retry
2220from agentlab .llm .tracking import cost_tracker_decorator
2321
24- from .visual_agent_prompts import GenericPromptFlags , MainPrompt
25- from functools import partial
22+ from .visual_agent_prompts import PromptFlags , MainPrompt
2623
2724
2825@dataclass
29- class ToolAgentFlags :
30- pass
31-
32-
33- @dataclass
34- class ToolAgentArgs (AgentArgs ):
26+ class VisualAgentArgs (AgentArgs ):
3527 chat_model_args : BaseModelArgs = None
36- flags : GenericPromptFlags = None
28+ flags : PromptFlags = None
3729 max_retry : int = 4
3830
3931 def __post_init__ (self ):
40- try : # some attributes might be temporarily args.CrossProd for hyperparameter generation
41- self .agent_name = f"GenericAgent -{ self .chat_model_args .model_name } " .replace ("/" , "_" )
32+ try : # some attributes might be missing temporarily due to args.CrossProd for hyperparameter generation
33+ self .agent_name = f"VisualAgent -{ self .chat_model_args .model_name } " .replace ("/" , "_" )
4234 except AttributeError :
4335 pass
4436
4537 def set_benchmark (self , benchmark : bgym .Benchmark , demo_mode ):
4638 """Override Some flags based on the benchmark."""
47- if benchmark .name .startswith ("miniwob" ):
48- self .flags .obs .use_html = True
49-
5039 self .flags .obs .use_tabs = benchmark .is_multi_tab
51- self .flags .action .action_set = deepcopy (benchmark .high_level_action_set_args )
52-
53- # for backward compatibility with old traces
54- if self .flags .action .multi_actions is not None :
55- self .flags .action .action_set .multiaction = self .flags .action .multi_actions
56- if self .flags .action .is_strict is not None :
57- self .flags .action .action_set .strict = self .flags .action .is_strict
58-
59- # verify if we can remove this
60- if demo_mode :
61- self .flags .action .action_set .demo_mode = "all_blue"
6240
6341 def set_reproducibility_mode (self ):
6442 self .chat_model_args .temperature = 0
@@ -70,17 +48,17 @@ def close(self):
7048 return self .chat_model_args .close_server ()
7149
7250 def make_agent (self ):
73- return ToolAgent (
51+ return VisualAgent (
7452 chat_model_args = self .chat_model_args , flags = self .flags , max_retry = self .max_retry
7553 )
7654
7755
78- class ToolAgent (Agent ):
56+ class VisualAgent (Agent ):
7957
8058 def __init__ (
8159 self ,
8260 chat_model_args : BaseModelArgs ,
83- flags : GenericPromptFlags ,
61+ flags : PromptFlags ,
8462 max_retry : int = 4 ,
8563 ):
8664
@@ -92,7 +70,6 @@ def __init__(
9270 self .action_set = self .flags .action .action_set .make_action_set ()
9371 self ._obs_preprocessor = dp .make_obs_preprocessor (flags .obs )
9472
95- self ._check_flag_constancy ()
9673 self .reset (seed = None )
9774
9875 def obs_preprocessor (self , obs : dict ) -> dict :
@@ -101,34 +78,20 @@ def obs_preprocessor(self, obs: dict) -> dict:
10178 @cost_tracker_decorator
10279 def get_action (self , obs ):
10380
104- self .obs_history .append (obs )
10581 main_prompt = MainPrompt (
10682 action_set = self .action_set ,
107- obs_history = self . obs_history ,
83+ obs = obs ,
10884 actions = self .actions ,
109- memories = self .memories ,
11085 thoughts = self .thoughts ,
111- previous_plan = self .plan ,
112- step = self .plan_step ,
11386 flags = self .flags ,
11487 )
11588
116- max_prompt_tokens , max_trunc_itr = self ._get_maxes ()
117-
11889 system_prompt = SystemMessage (dp .SystemPrompt ().prompt )
119-
120- human_prompt = dp .fit_tokens (
121- shrinkable = main_prompt ,
122- max_prompt_tokens = max_prompt_tokens ,
123- model_name = self .chat_model_args .model_name ,
124- max_iterations = max_trunc_itr ,
125- additional_prompts = system_prompt ,
126- )
12790 try :
12891 # TODO, we would need to further shrink the prompt if the retry
12992 # cause it to be too long
13093
131- chat_messages = Discussion ([system_prompt , human_prompt ])
94+ chat_messages = Discussion ([system_prompt , main_prompt . prompt ])
13295 ans_dict = retry (
13396 self .chat_llm ,
13497 chat_messages ,
@@ -138,7 +101,7 @@ def get_action(self, obs):
138101 ans_dict ["busted_retry" ] = 0
139102 # inferring the number of retries, TODO: make this less hacky
140103 ans_dict ["n_retry" ] = (len (chat_messages ) - 3 ) / 2
141- except ParseError as e :
104+ except ParseError :
142105 ans_dict = dict (
143106 action = None ,
144107 n_retry = self .max_retry + 1 ,
@@ -149,10 +112,7 @@ def get_action(self, obs):
149112 stats ["n_retry" ] = ans_dict ["n_retry" ]
150113 stats ["busted_retry" ] = ans_dict ["busted_retry" ]
151114
152- self .plan = ans_dict .get ("plan" , self .plan )
153- self .plan_step = ans_dict .get ("step" , self .plan_step )
154115 self .actions .append (ans_dict ["action" ])
155- self .memories .append (ans_dict .get ("memory" , None ))
156116 self .thoughts .append (ans_dict .get ("think" , None ))
157117
158118 agent_info = AgentInfo (
@@ -165,43 +125,5 @@ def get_action(self, obs):
165125
166126 def reset (self , seed = None ):
167127 self .seed = seed
168- self .plan = "No plan yet"
169- self .plan_step = - 1
170- self .memories = []
171128 self .thoughts = []
172129 self .actions = []
173- self .obs_history = []
174-
175- def _check_flag_constancy (self ):
176- flags = self .flags
177- if flags .obs .use_som :
178- if not flags .obs .use_screenshot :
179- warn (
180- """
181- Warning: use_som=True requires use_screenshot=True. Disabling use_som."""
182- )
183- flags .obs .use_som = False
184- if flags .obs .use_screenshot :
185- if not self .chat_model_args .vision_support :
186- warn (
187- """
188- Warning: use_screenshot is set to True, but the chat model \
189- does not support vision. Disabling use_screenshot."""
190- )
191- flags .obs .use_screenshot = False
192- return flags
193-
194- def _get_maxes (self ):
195- maxes = (
196- self .flags .max_prompt_tokens ,
197- self .chat_model_args .max_total_tokens ,
198- self .chat_model_args .max_input_tokens ,
199- )
200- maxes = [m for m in maxes if m is not None ]
201- max_prompt_tokens = min (maxes ) if maxes else None
202- max_trunc_itr = (
203- self .flags .max_trunc_itr
204- if self .flags .max_trunc_itr
205- else 20 # dangerous to change the default value here?
206- )
207- return max_prompt_tokens , max_trunc_itr
0 commit comments