Skip to content

Commit c25dc84

Browse files
committed
Implement VisualAgent and associated prompt flags for enhanced agent functionality
1 parent 1b1c4dc commit c25dc84

3 files changed

Lines changed: 69 additions & 121 deletions

File tree

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT
2+
3+
from .visual_agent import VisualAgentArgs
4+
from .visual_agent_prompts import PromptFlags
5+
import agentlab.agents.dynamic_prompting as dp
6+
import bgym
7+
8+
# the other flags are ignored for this agent.
9+
DEFAULT_OBS_FLAGS = dp.ObsFlags(
10+
use_tabs=True, # will be overridden by the benchmark when set_benchmark is called after initializing the agent
11+
use_error_logs=True,
12+
use_past_error_logs=False,
13+
use_screenshot=True,
14+
use_som=False,
15+
openai_vision_detail="auto",
16+
)
17+
18+
DEFAULT_ACTION_FLAGS = dp.ActionFlags(
19+
action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]),
20+
long_description=True,
21+
individual_examples=False,
22+
)
23+
24+
25+
DEFAULT_PROMPT_FLAGS = PromptFlags(
26+
obs=DEFAULT_OBS_FLAGS,
27+
action=DEFAULT_ACTION_FLAGS,
28+
use_thinking=True,
29+
use_concrete_example=False,
30+
use_abstract_example=True,
31+
enable_chat=False,
32+
extra_instructions=None,
33+
)
34+
35+
VISUAL_AGENT_4o = VisualAgentArgs(
36+
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/gpt-4o-2024-05-13"],
37+
flags=DEFAULT_PROMPT_FLAGS,
38+
)
39+
40+
VISUAL_AGENT_COMPUTER_USE = VisualAgentArgs(
41+
chat_model_args=CHAT_MODEL_ARGS_DICT["openai/computer-use-preview-2025-03-11"],
42+
flags=DEFAULT_PROMPT_FLAGS,
43+
)
44+
45+
VISUAL_AGENT_CLAUDE_3_5 = VisualAgentArgs(
46+
chat_model_args=CHAT_MODEL_ARGS_DICT["openrouter/anthropic/claude-3.5-sonnet:beta"],
47+
flags=DEFAULT_PROMPT_FLAGS,
48+
)

src/agentlab/agents/visual_agent/visual_agent.py

Lines changed: 11 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88
the agent, including model arguments and flags for various behaviors.
99
"""
1010

11-
from copy import deepcopy
1211
from dataclasses import asdict, dataclass
13-
from warnings import warn
1412

1513
import bgym
1614
from browsergym.experiments.agent import Agent, AgentInfo
@@ -21,44 +19,24 @@
2119
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
2220
from agentlab.llm.tracking import cost_tracker_decorator
2321

24-
from .visual_agent_prompts import GenericPromptFlags, MainPrompt
25-
from functools import partial
22+
from .visual_agent_prompts import PromptFlags, MainPrompt
2623

2724

2825
@dataclass
29-
class ToolAgentFlags:
30-
pass
31-
32-
33-
@dataclass
34-
class ToolAgentArgs(AgentArgs):
26+
class VisualAgentArgs(AgentArgs):
3527
chat_model_args: BaseModelArgs = None
36-
flags: GenericPromptFlags = None
28+
flags: PromptFlags = None
3729
max_retry: int = 4
3830

3931
def __post_init__(self):
40-
try: # some attributes might be temporarily args.CrossProd for hyperparameter generation
41-
self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_")
32+
try: # some attributes might be missing temporarily due to args.CrossProd for hyperparameter generation
33+
self.agent_name = f"VisualAgent-{self.chat_model_args.model_name}".replace("/", "_")
4234
except AttributeError:
4335
pass
4436

4537
def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
4638
"""Override Some flags based on the benchmark."""
47-
if benchmark.name.startswith("miniwob"):
48-
self.flags.obs.use_html = True
49-
5039
self.flags.obs.use_tabs = benchmark.is_multi_tab
51-
self.flags.action.action_set = deepcopy(benchmark.high_level_action_set_args)
52-
53-
# for backward compatibility with old traces
54-
if self.flags.action.multi_actions is not None:
55-
self.flags.action.action_set.multiaction = self.flags.action.multi_actions
56-
if self.flags.action.is_strict is not None:
57-
self.flags.action.action_set.strict = self.flags.action.is_strict
58-
59-
# verify if we can remove this
60-
if demo_mode:
61-
self.flags.action.action_set.demo_mode = "all_blue"
6240

6341
def set_reproducibility_mode(self):
6442
self.chat_model_args.temperature = 0
@@ -70,17 +48,17 @@ def close(self):
7048
return self.chat_model_args.close_server()
7149

7250
def make_agent(self):
73-
return ToolAgent(
51+
return VisualAgent(
7452
chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
7553
)
7654

7755

78-
class ToolAgent(Agent):
56+
class VisualAgent(Agent):
7957

8058
def __init__(
8159
self,
8260
chat_model_args: BaseModelArgs,
83-
flags: GenericPromptFlags,
61+
flags: PromptFlags,
8462
max_retry: int = 4,
8563
):
8664

@@ -92,7 +70,6 @@ def __init__(
9270
self.action_set = self.flags.action.action_set.make_action_set()
9371
self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)
9472

95-
self._check_flag_constancy()
9673
self.reset(seed=None)
9774

9875
def obs_preprocessor(self, obs: dict) -> dict:
@@ -101,34 +78,20 @@ def obs_preprocessor(self, obs: dict) -> dict:
10178
@cost_tracker_decorator
10279
def get_action(self, obs):
10380

104-
self.obs_history.append(obs)
10581
main_prompt = MainPrompt(
10682
action_set=self.action_set,
107-
obs_history=self.obs_history,
83+
obs=obs,
10884
actions=self.actions,
109-
memories=self.memories,
11085
thoughts=self.thoughts,
111-
previous_plan=self.plan,
112-
step=self.plan_step,
11386
flags=self.flags,
11487
)
11588

116-
max_prompt_tokens, max_trunc_itr = self._get_maxes()
117-
11889
system_prompt = SystemMessage(dp.SystemPrompt().prompt)
119-
120-
human_prompt = dp.fit_tokens(
121-
shrinkable=main_prompt,
122-
max_prompt_tokens=max_prompt_tokens,
123-
model_name=self.chat_model_args.model_name,
124-
max_iterations=max_trunc_itr,
125-
additional_prompts=system_prompt,
126-
)
12790
try:
12891
# TODO, we would need to further shrink the prompt if the retry
12992
# causes it to be too long
13093

131-
chat_messages = Discussion([system_prompt, human_prompt])
94+
chat_messages = Discussion([system_prompt, main_prompt.prompt])
13295
ans_dict = retry(
13396
self.chat_llm,
13497
chat_messages,
@@ -138,7 +101,7 @@ def get_action(self, obs):
138101
ans_dict["busted_retry"] = 0
139102
# inferring the number of retries, TODO: make this less hacky
140103
ans_dict["n_retry"] = (len(chat_messages) - 3) / 2
141-
except ParseError as e:
104+
except ParseError:
142105
ans_dict = dict(
143106
action=None,
144107
n_retry=self.max_retry + 1,
@@ -149,10 +112,7 @@ def get_action(self, obs):
149112
stats["n_retry"] = ans_dict["n_retry"]
150113
stats["busted_retry"] = ans_dict["busted_retry"]
151114

152-
self.plan = ans_dict.get("plan", self.plan)
153-
self.plan_step = ans_dict.get("step", self.plan_step)
154115
self.actions.append(ans_dict["action"])
155-
self.memories.append(ans_dict.get("memory", None))
156116
self.thoughts.append(ans_dict.get("think", None))
157117

158118
agent_info = AgentInfo(
@@ -165,43 +125,5 @@ def get_action(self, obs):
165125

166126
def reset(self, seed=None):
167127
self.seed = seed
168-
self.plan = "No plan yet"
169-
self.plan_step = -1
170-
self.memories = []
171128
self.thoughts = []
172129
self.actions = []
173-
self.obs_history = []
174-
175-
def _check_flag_constancy(self):
176-
flags = self.flags
177-
if flags.obs.use_som:
178-
if not flags.obs.use_screenshot:
179-
warn(
180-
"""
181-
Warning: use_som=True requires use_screenshot=True. Disabling use_som."""
182-
)
183-
flags.obs.use_som = False
184-
if flags.obs.use_screenshot:
185-
if not self.chat_model_args.vision_support:
186-
warn(
187-
"""
188-
Warning: use_screenshot is set to True, but the chat model \
189-
does not support vision. Disabling use_screenshot."""
190-
)
191-
flags.obs.use_screenshot = False
192-
return flags
193-
194-
def _get_maxes(self):
195-
maxes = (
196-
self.flags.max_prompt_tokens,
197-
self.chat_model_args.max_total_tokens,
198-
self.chat_model_args.max_input_tokens,
199-
)
200-
maxes = [m for m in maxes if m is not None]
201-
max_prompt_tokens = min(maxes) if maxes else None
202-
max_trunc_itr = (
203-
self.flags.max_trunc_itr
204-
if self.flags.max_trunc_itr
205-
else 20 # dangerous to change the default value here?
206-
)
207-
return max_prompt_tokens, max_trunc_itr

src/agentlab/agents/visual_agent/visual_agent_prompts.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import logging
88
from dataclasses import dataclass
9+
import bgym
910

1011
from browsergym.core.action.base import AbstractActionSet
1112

@@ -17,35 +18,15 @@
1718
class PromptFlags(dp.Flags):
1819
"""
1920
A class to represent various flags used to control features in an application.
20-
21-
Attributes:
22-
use_criticise (bool): Ask the LLM to first draft and criticise the action before producing it.
23-
use_thinking (bool): Enable a chain of thoughts.
24-
use_concrete_example (bool): Use a concrete example of the answer in the prompt for a generic task.
25-
use_abstract_example (bool): Use an abstract example of the answer in the prompt.
26-
use_hints (bool): Add some human-engineered hints to the prompt.
27-
enable_chat (bool): Enable chat mode, where the agent can interact with the user.
28-
max_prompt_tokens (int): Maximum number of tokens allowed in the prompt.
29-
be_cautious (bool): Instruct the agent to be cautious about its actions.
30-
extra_instructions (Optional[str]): Extra instructions to provide to the agent.
31-
add_missparsed_messages (bool): When retrying, add the missparsed messages to the prompt.
32-
flag_group (Optional[str]): Group of flags used.
3321
"""
3422

35-
obs: dp.ObsFlags
36-
action: dp.ActionFlags
37-
use_criticise: bool = False #
38-
use_thinking: bool = False
39-
use_concrete_example: bool = True
40-
use_abstract_example: bool = False
41-
use_hints: bool = False
23+
obs: dp.ObsFlags = None
24+
action: dp.ActionFlags = None
25+
use_thinking: bool = True
26+
use_concrete_example: bool = False
27+
use_abstract_example: bool = True
4228
enable_chat: bool = False
43-
max_prompt_tokens: int = None
44-
be_cautious: bool = True
4529
extra_instructions: str | None = None
46-
add_missparsed_messages: bool = True
47-
max_trunc_itr: int = 20
48-
flag_group: str = None
4930

5031

5132
class SystemPrompt(dp.PromptElement):
@@ -77,7 +58,7 @@ class History(dp.PromptElement):
7758
Format the actions and thoughts of previous steps."""
7859

7960
def __init__(self, actions, thoughts) -> None:
80-
61+
super().__init__()
8162
prompt_elements = []
8263
for i, (action, thought) in enumerate(zip(actions, thoughts)):
8364
prompt_elements.append(
@@ -121,7 +102,7 @@ def __init__(self, obs, flags: dp.ObsFlags) -> None:
121102
def _prompt(self) -> str:
122103
return f"""
123104
# Observation of current step:
124-
{self.tabs.prompt}{self.focused_element.prompt}{self.error.prompt}
105+
{self.tabs.prompt}{self.error.prompt}
125106
126107
"""
127108

@@ -152,12 +133,9 @@ def __init__(
152133
) -> None:
153134
super().__init__()
154135
self.flags = flags
155-
self.history = History(obs, actions, thoughts)
136+
self.history = History(actions, thoughts)
156137
self.instructions = make_instructions(obs, flags.enable_chat, flags.extra_instructions)
157-
self.obs = dp.Observation(
158-
obs,
159-
self.flags.obs,
160-
)
138+
self.obs = Observation(obs, self.flags.obs)
161139

162140
self.action_prompt = dp.ActionPrompt(action_set, action_flags=flags.action)
163141
self.think = dp.Think(visible=lambda: flags.use_thinking)

0 commit comments

Comments
 (0)