Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 207 additions & 0 deletions src/agentlab/agents/visual_agent/visual_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
GenericAgent implementation for AgentLab

This module defines a `GenericAgent` class and its associated arguments for use in the AgentLab framework. \
The `GenericAgent` class is designed to interact with a chat-based model to determine actions based on \
observations. It includes methods for preprocessing observations, generating actions, and managing internal \
state such as plans, memories, and thoughts. The `GenericAgentArgs` class provides configuration options for \
the agent, including model arguments and flags for various behaviors.
"""

from copy import deepcopy
from dataclasses import asdict, dataclass
from warnings import warn

import bgym
from browsergym.experiments.agent import Agent, AgentInfo

from agentlab.agents import dynamic_prompting as dp
from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.chat_api import BaseModelArgs
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator

from .visual_agent_prompts import GenericPromptFlags, MainPrompt
from functools import partial


@dataclass
class ToolAgentFlags:
pass

This comment was marked as resolved.



@dataclass
class ToolAgentArgs(AgentArgs):
chat_model_args: BaseModelArgs = None
flags: GenericPromptFlags = None
max_retry: int = 4

def __post_init__(self):
try: # some attributes might be temporarily args.CrossProd for hyperparameter generation
self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_")
except AttributeError:
pass
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Silent AttributeError suppression category Error Handling

Tell me more
What is the issue?

A bare except-pass statement silently ignores AttributeError without any handling or logging.

Why this matters

Silent failure makes it difficult to identify configuration issues or invalid states during agent initialization.

Suggested change ∙ Feature Preview
try:
    self.agent_name = f"GenericAgent-{self.chat_model_args.model_name}".replace("/", "_")
except AttributeError:
    # Set a default name or log that model name was not available
    self.agent_name = "GenericAgent-unknown"
    warn("Could not set agent name with model name - using default name")
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.


def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
"""Override Some flags based on the benchmark."""
if benchmark.name.startswith("miniwob"):
self.flags.obs.use_html = True

self.flags.obs.use_tabs = benchmark.is_multi_tab
self.flags.action.action_set = deepcopy(benchmark.high_level_action_set_args)

# for backward compatibility with old traces
if self.flags.action.multi_actions is not None:
self.flags.action.action_set.multiaction = self.flags.action.multi_actions
if self.flags.action.is_strict is not None:
self.flags.action.action_set.strict = self.flags.action.is_strict

# verify if we can remove this
if demo_mode:
self.flags.action.action_set.demo_mode = "all_blue"

def set_reproducibility_mode(self):
self.chat_model_args.temperature = 0

def prepare(self):
return self.chat_model_args.prepare_server()

def close(self):
return self.chat_model_args.close_server()

def make_agent(self):
return ToolAgent(
chat_model_args=self.chat_model_args, flags=self.flags, max_retry=self.max_retry
)


class ToolAgent(Agent):

def __init__(
self,
chat_model_args: BaseModelArgs,
flags: GenericPromptFlags,
max_retry: int = 4,
):

self.chat_llm = chat_model_args.make_model()
self.chat_model_args = chat_model_args
self.max_retry = max_retry

self.flags = flags
self.action_set = self.flags.action.action_set.make_action_set()
self._obs_preprocessor = dp.make_obs_preprocessor(flags.obs)

self._check_flag_constancy()
self.reset(seed=None)

def obs_preprocessor(self, obs: dict) -> dict:
return self._obs_preprocessor(obs)

@cost_tracker_decorator
def get_action(self, obs):

self.obs_history.append(obs)

This comment was marked as resolved.

main_prompt = MainPrompt(
Comment on lines +102 to +81

This comment was marked as resolved.

action_set=self.action_set,
obs_history=self.obs_history,
actions=self.actions,
memories=self.memories,
thoughts=self.thoughts,
previous_plan=self.plan,
step=self.plan_step,
flags=self.flags,
)

max_prompt_tokens, max_trunc_itr = self._get_maxes()

system_prompt = SystemMessage(dp.SystemPrompt().prompt)

human_prompt = dp.fit_tokens(
shrinkable=main_prompt,
max_prompt_tokens=max_prompt_tokens,
model_name=self.chat_model_args.model_name,
max_iterations=max_trunc_itr,
additional_prompts=system_prompt,
)
try:
# TODO, we would need to further shrink the prompt if the retry
# cause it to be too long
Comment on lines +91 to +92
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-actionable TODO comment category Documentation

Tell me more
What is the issue?

TODO comment is not actionable and lacks context about implementation details.

Why this matters

Future developers won't understand what specifically needs to be implemented or why.

Suggested change ∙ Feature Preview

TODO: Implement dynamic prompt shrinking when retries cause token limits to be exceeded.

        # Need to track token count during retries and adjust max_prompt_tokens accordingly.
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.


chat_messages = Discussion([system_prompt, human_prompt])
ans_dict = retry(
self.chat_llm,
chat_messages,
n_retry=self.max_retry,
parser=main_prompt._parse_answer,
)
ans_dict["busted_retry"] = 0
# inferring the number of retries, TODO: make this less hacky
ans_dict["n_retry"] = (len(chat_messages) - 3) / 2
Comment on lines +102 to +103
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Complex retry calculation with magic numbers category Readability

Tell me more
What is the issue?

Complex calculation with magic numbers (3, 2) used to infer retry count.

Why this matters

The formula's intent is unclear and forces readers to reverse engineer the logic.

Suggested change ∙ Feature Preview

Extract the calculation into a well-named method like _calculate_retry_count() that explains the logic and uses named constants for the magic numbers.

Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.

except ParseError as e:
ans_dict = dict(
action=None,
n_retry=self.max_retry + 1,
busted_retry=1,
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ParseError details lost category Error Handling

Tell me more
What is the issue?

The ParseError exception is caught but the error information (e) is not logged or used, resulting in loss of valuable debugging context.

Why this matters

Without capturing the error details, debugging production issues becomes more difficult as the root cause of parsing failures cannot be traced.

Suggested change ∙ Feature Preview
except ParseError as e:
    ans_dict = dict(
        action=None,
        n_retry=self.max_retry + 1,
        busted_retry=1,
        error_details=str(e)  # Preserve error information
    )
Provide feedback to improve future suggestions

Nice Catch Incorrect Not in Scope Not in coding standard Other

💬 Looking for more details? Reply to this comment to chat with Korbit.


stats = self.chat_llm.get_stats()
stats["n_retry"] = ans_dict["n_retry"]
stats["busted_retry"] = ans_dict["busted_retry"]

self.plan = ans_dict.get("plan", self.plan)
self.plan_step = ans_dict.get("step", self.plan_step)
self.actions.append(ans_dict["action"])
self.memories.append(ans_dict.get("memory", None))
self.thoughts.append(ans_dict.get("think", None))

agent_info = AgentInfo(
think=ans_dict.get("think", None),
chat_messages=chat_messages,
stats=stats,
extra_info={"chat_model_args": asdict(self.chat_model_args)},
)
return ans_dict["action"], agent_info

def reset(self, seed=None):
self.seed = seed
self.plan = "No plan yet"
self.plan_step = -1
self.memories = []
self.thoughts = []
self.actions = []
self.obs_history = []

def _check_flag_constancy(self):
flags = self.flags
if flags.obs.use_som:
if not flags.obs.use_screenshot:
warn(
"""
Warning: use_som=True requires use_screenshot=True. Disabling use_som."""
)
flags.obs.use_som = False
if flags.obs.use_screenshot:
if not self.chat_model_args.vision_support:
warn(
"""
Warning: use_screenshot is set to True, but the chat model \
does not support vision. Disabling use_screenshot."""
)
flags.obs.use_screenshot = False
return flags

def _get_maxes(self):
maxes = (
self.flags.max_prompt_tokens,
self.chat_model_args.max_total_tokens,
self.chat_model_args.max_input_tokens,
)
maxes = [m for m in maxes if m is not None]
max_prompt_tokens = min(maxes) if maxes else None
max_trunc_itr = (
self.flags.max_trunc_itr
if self.flags.max_trunc_itr
else 20 # dangerous to change the default value here?
)
return max_prompt_tokens, max_trunc_itr
Loading