-
Notifications
You must be signed in to change notification settings - Fork 115
Create a simple pure visual agent. #235
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
1b1c4dc
c25dc84
ba8c91e
a5a8ef4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,207 @@ | ||
| """ | ||
| ToolAgent implementation for AgentLab | ||
|
|
||
| This module defines a `ToolAgent` class and its associated arguments for use in the AgentLab framework. \ | ||
| The `ToolAgent` class is designed to interact with a chat-based model to determine actions based on \ | ||
| observations. It includes methods for preprocessing observations, generating actions, and managing internal \ | ||
| state such as plans, memories, and thoughts. The `ToolAgentArgs` class provides configuration options for \ | ||
| the agent, including model arguments and flags for various behaviors. | ||
| """ | ||
|
|
||
| from copy import deepcopy | ||
| from dataclasses import asdict, dataclass | ||
| from warnings import warn | ||
|
|
||
| import bgym | ||
| from browsergym.experiments.agent import Agent, AgentInfo | ||
|
|
||
| from agentlab.agents import dynamic_prompting as dp | ||
| from agentlab.agents.agent_args import AgentArgs | ||
| from agentlab.llm.chat_api import BaseModelArgs | ||
| from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry | ||
| from agentlab.llm.tracking import cost_tracker_decorator | ||
|
|
||
| from .visual_agent_prompts import GenericPromptFlags, MainPrompt | ||
| from functools import partial | ||
|
|
||
|
|
||
@dataclass
class ToolAgentFlags:
    """Flag container for the tool agent; it currently declares no fields."""
|
|
||
|
|
||
| @dataclass | ||
| class ToolAgentArgs(AgentArgs): | ||
| chat_model_args: BaseModelArgs = None | ||
| flags: GenericPromptFlags = None | ||
| max_retry: int = 4 | ||
|
|
||
| def __post_init__(self): | ||
| try: # some attributes might be temporarily args.CrossProd for hyperparameter generation | ||
| self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_") | ||
| except AttributeError: | ||
| pass | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Silent AttributeError suppression
Tell me more
What is the issue? An except-pass statement silently ignores AttributeError without any handling or logging.
Why this matters: Silent failure makes it difficult to identify configuration issues or invalid states during agent initialization.
Suggested change ∙ Feature Preview
try:
    self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_")
except AttributeError:
    # Set a default name or log that model name was not available
    self.agent_name = "GenericAgent-unknown"
    warn("Could not set agent name with model name - using default name")
Provide feedback to improve future suggestions. 💬 Looking for more details? Reply to this comment to chat with Korbit. |
||
|
|
||
def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
    """Override some flags based on the benchmark.

    Args:
        benchmark: Benchmark whose ``name``, ``is_multi_tab`` and
            ``high_level_action_set_args`` attributes are read.
        demo_mode: When truthy, the action set's ``demo_mode`` is set to
            ``"all_blue"``.
    """
    if benchmark.name.startswith("miniwob"):
        self.flags.obs.use_html = True
    self.flags.obs.use_tabs = benchmark.is_multi_tab

    action_flags = self.flags.action
    # Deep-copied so later tweaks never touch the benchmark's own arguments.
    action_set_args = deepcopy(benchmark.high_level_action_set_args)
    action_flags.action_set = action_set_args

    # for backward compatibility with old traces
    if action_flags.multi_actions is not None:
        action_set_args.multiaction = action_flags.multi_actions
    if action_flags.is_strict is not None:
        action_set_args.strict = action_flags.is_strict

    # verify if we can remove this
    if not demo_mode:
        return
    action_set_args.demo_mode = "all_blue"
|
|
||
def set_reproducibility_mode(self):
    """Set the chat model's ``temperature`` argument to 0."""
    model_args = self.chat_model_args
    model_args.temperature = 0
|
|
||
def prepare(self):
    """Call ``prepare_server()`` on the chat model args and return its result."""
    model_args = self.chat_model_args
    return model_args.prepare_server()
|
|
||
def close(self):
    """Call ``close_server()`` on the chat model args and return its result."""
    model_args = self.chat_model_args
    return model_args.close_server()
|
|
||
def make_agent(self):
    """Build a ``ToolAgent`` from this configuration's model args, flags and retry budget."""
    agent_kwargs = dict(
        chat_model_args=self.chat_model_args,
        flags=self.flags,
        max_retry=self.max_retry,
    )
    return ToolAgent(**agent_kwargs)
|
|
||
|
|
||
| class ToolAgent(Agent): | ||
|
|
||
def __init__(
    self,
    chat_model_args: BaseModelArgs,
    flags: GenericPromptFlags,
    max_retry: int = 4,
):
    """Create the agent: build the chat model, action set and observation preprocessor.

    Args:
        chat_model_args: Model configuration; its ``make_model()`` result is
            stored as ``self.chat_llm``.
        flags: Prompt flags; ``flags.action.action_set`` builds the action set
            and ``flags.obs`` configures the observation preprocessor.
        max_retry: Retry budget stored on the agent for later LLM calls.
    """
    llm = chat_model_args.make_model()
    self.chat_llm = llm
    self.chat_model_args = chat_model_args
    self.max_retry = max_retry
    self.flags = flags

    action_set_args = flags.action.action_set
    self.action_set = action_set_args.make_action_set()
    # NOTE(review): the preprocessor is built before _check_flag_constancy(),
    # which may switch off use_som / use_screenshot on flags.obs — confirm
    # that make_obs_preprocessor reads those flags lazily.
    self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)

    self._check_flag_constancy()
    # Initialise the per-episode state (plan, histories, seed).
    self.reset(seed=None)
|
|
||
def obs_preprocessor(self, obs: dict) -> dict:
    """Run ``obs`` through the preprocessor stored on the agent and return its result."""
    preprocess = self._obs_preprocessor
    return preprocess(obs)
|
|
||
| @cost_tracker_decorator | ||
| def get_action(self, obs): | ||
|
|
||
| self.obs_history.append(obs) | ||
This comment was marked as resolved.
Sorry, something went wrong. |
||
| main_prompt = MainPrompt( | ||
|
Comment on lines
+102
to
+81
This comment was marked as resolved.
Sorry, something went wrong. |
||
| action_set=self.action_set, | ||
| obs_history=self.obs_history, | ||
| actions=self.actions, | ||
| memories=self.memories, | ||
| thoughts=self.thoughts, | ||
| previous_plan=self.plan, | ||
| step=self.plan_step, | ||
| flags=self.flags, | ||
| ) | ||
|
|
||
| max_prompt_tokens, max_trunc_itr = self._get_maxes() | ||
|
|
||
| system_prompt = SystemMessage(dp.SystemPrompt().prompt) | ||
|
|
||
| human_prompt = dp.fit_tokens( | ||
| shrinkable=main_prompt, | ||
| max_prompt_tokens=max_prompt_tokens, | ||
| model_name=self.chat_model_args.model_name, | ||
| max_iterations=max_trunc_itr, | ||
| additional_prompts=system_prompt, | ||
| ) | ||
| try: | ||
| # TODO, we would need to further shrink the prompt if the retry | ||
| # cause it to be too long | ||
|
Comment on lines
+91
to
+92
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Non-actionable TODO comment
Tell me more
What is the issue? The TODO comment is not actionable and lacks context about implementation details.
Why this matters: Future developers won't understand what specifically needs to be implemented or why.
Suggested change ∙ Feature Preview
TODO: Implement dynamic prompt shrinking when retries cause token limits to be exceeded.
Provide feedback to improve future suggestions. 💬 Looking for more details? Reply to this comment to chat with Korbit. |
||
|
|
||
| chat_messages = Discussion([system_prompt, human_prompt]) | ||
| ans_dict = retry( | ||
| self.chat_llm, | ||
| chat_messages, | ||
| n_retry=self.max_retry, | ||
| parser=main_prompt._parse_answer, | ||
| ) | ||
| ans_dict["busted_retry"] = 0 | ||
| # inferring the number of retries, TODO: make this less hacky | ||
| ans_dict["n_retry"] = (len(chat_messages) - 3) / 2 | ||
|
Comment on lines
+102
to
+103
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Complex retry calculation with magic numbers
Tell me more
What is the issue? A complex calculation with magic numbers (3, 2) is used to infer the retry count.
Why this matters: The formula's intent is unclear and forces readers to reverse-engineer the logic.
Suggested change ∙ Feature Preview
Extract the calculation into a well-named method (the example method name is missing from this capture).
Provide feedback to improve future suggestions. 💬 Looking for more details? Reply to this comment to chat with Korbit. |
||
| except ParseError as e: | ||
| ans_dict = dict( | ||
| action=None, | ||
| n_retry=self.max_retry + 1, | ||
| busted_retry=1, | ||
| ) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
ParseError details lost
Tell me more
What is the issue? The ParseError exception is caught, but the error information (e) is not logged or used, resulting in the loss of valuable debugging context.
Why this matters: Without capturing the error details, debugging production issues becomes more difficult because the root cause of parsing failures cannot be traced.
Suggested change ∙ Feature Preview
except ParseError as e:
    ans_dict = dict(
        action=None,
        n_retry=self.max_retry + 1,
        busted_retry=1,
        error_details=str(e)  # Preserve error information
    )
Provide feedback to improve future suggestions. 💬 Looking for more details? Reply to this comment to chat with Korbit. |
||
|
|
||
| stats = self.chat_llm.get_stats() | ||
| stats["n_retry"] = ans_dict["n_retry"] | ||
| stats["busted_retry"] = ans_dict["busted_retry"] | ||
|
|
||
| self.plan = ans_dict.get("plan", self.plan) | ||
| self.plan_step = ans_dict.get("step", self.plan_step) | ||
| self.actions.append(ans_dict["action"]) | ||
| self.memories.append(ans_dict.get("memory", None)) | ||
| self.thoughts.append(ans_dict.get("think", None)) | ||
|
|
||
| agent_info = AgentInfo( | ||
| think=ans_dict.get("think", None), | ||
| chat_messages=chat_messages, | ||
| stats=stats, | ||
| extra_info={"chat_model_args": asdict(self.chat_model_args)}, | ||
| ) | ||
| return ans_dict["action"], agent_info | ||
|
|
||
def reset(self, seed=None):
    """Clear the agent's per-episode state and record ``seed``.

    The plan goes back to its placeholder text, the plan step to -1, and
    each history (memories, thoughts, actions, observations) becomes a new
    empty list.
    """
    self.seed = seed
    # Planning state.
    self.plan, self.plan_step = "No plan yet", -1
    # Histories: four independent lists, never shared with each other.
    self.memories, self.thoughts, self.actions, self.obs_history = [], [], [], []
|
|
||
| def _check_flag_constancy(self): | ||
| flags = self.flags | ||
| if flags.obs.use_som: | ||
| if not flags.obs.use_screenshot: | ||
| warn( | ||
| """ | ||
| Warning: use_som=True requires use_screenshot=True. Disabling use_som.""" | ||
| ) | ||
| flags.obs.use_som = False | ||
| if flags.obs.use_screenshot: | ||
| if not self.chat_model_args.vision_support: | ||
| warn( | ||
| """ | ||
| Warning: use_screenshot is set to True, but the chat model \ | ||
| does not support vision. Disabling use_screenshot.""" | ||
| ) | ||
| flags.obs.use_screenshot = False | ||
| return flags | ||
|
|
||
| def _get_maxes(self): | ||
| maxes = ( | ||
| self.flags.max_prompt_tokens, | ||
| self.chat_model_args.max_total_tokens, | ||
| self.chat_model_args.max_input_tokens, | ||
| ) | ||
| maxes = [m for m in maxes if m is not None] | ||
| max_prompt_tokens = min(maxes) if maxes else None | ||
| max_trunc_itr = ( | ||
| self.flags.max_trunc_itr | ||
| if self.flags.max_trunc_itr | ||
| else 20 # dangerous to change the default value here? | ||
| ) | ||
| return max_prompt_tokens, max_trunc_itr | ||
This comment was marked as resolved.
Sorry, something went wrong.
Uh oh!
There was an error while loading. Please reload this page.