Skip to content

Commit 6087c8f

Browse files
committed
updated default agent with verbosity level
1 parent 782c9bf commit 6087c8f

1 file changed

Lines changed: 44 additions & 4 deletions

File tree

maseval/benchmark/tau2/tau2.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ class DefaultTau2Agent:
497497
model: ModelAdapter for LLM calls
498498
llm_args: Additional arguments for LLM calls
499499
max_tool_calls: Maximum tool calls per turn (prevents infinite loops)
500+
verbose: Verbosity level (0=silent, 1=basic, 2=detailed)
500501
"""
501502

502503
def __init__(
@@ -506,6 +507,7 @@ def __init__(
506507
model: ModelAdapter,
507508
llm_args: Optional[Dict[str, Any]] = None,
508509
max_tool_calls: int = 50,
510+
verbose: int = 0,
509511
):
510512
"""Initialize the default tau2 agent.
511513
@@ -515,12 +517,17 @@ def __init__(
515517
model: ModelAdapter for making LLM calls
516518
llm_args: Optional additional arguments passed to model.generate()
517519
max_tool_calls: Maximum number of tool calls per agent turn
520+
verbose: Verbosity level for debugging output:
521+
- 0: Silent (no output)
522+
- 1: Basic (tool calls and responses)
523+
- 2: Detailed (full message contents, tool arguments and results)
518524
"""
519525
self.tools = tools
520526
self.policy = policy
521527
self.model = model
522528
self.llm_args = llm_args or {}
523529
self.max_tool_calls = max_tool_calls
530+
self.verbose = verbose
524531

525532
# Build system prompt
526533
self.system_prompt = _SYSTEM_PROMPT_TEMPLATE.format(
@@ -537,6 +544,16 @@ def reset(self) -> None:
537544
self._messages = []
538545
self._tool_call_count = 0
539546

547+
def _log(self, level: int, message: str) -> None:
548+
"""Print message if verbosity level is high enough.
549+
550+
Args:
551+
level: Minimum verbosity level required (1 or 2)
552+
message: Message to print
553+
"""
554+
if self.verbose >= level:
555+
print(message)
556+
540557
def run(self, query: str) -> str:
541558
"""Process a user query and return the agent's response.
542559
@@ -552,6 +569,8 @@ def run(self, query: str) -> str:
552569
Returns:
553570
Agent's text response to the user
554571
"""
572+
self._log(1, f"[Agent] Received query: {query[:100]}{'...' if len(query) > 100 else ''}")
573+
555574
# Add user message to history
556575
self._messages.append({"role": "user", "content": query})
557576

@@ -572,6 +591,7 @@ def _generate_with_tools(self) -> str:
572591
while self._tool_call_count < self.max_tool_calls:
573592
# Build messages for LLM call
574593
messages = [{"role": "system", "content": self.system_prompt}] + self._messages
594+
self._log(2, f"[Agent] Generating response (messages: {len(messages)}, tools: {len(self.tools)})")
575595

576596
# Generate response with tool access using chat() method
577597
response = self.model.chat(
@@ -585,6 +605,8 @@ def _generate_with_tools(self) -> str:
585605
tool_calls = response.tool_calls or []
586606

587607
if tool_calls:
608+
self._log(1, f"[Agent] Tool calls: {[self._get_tool_name(tc) for tc in tool_calls]}")
609+
588610
# Add assistant message with tool calls
589611
self._messages.append(
590612
{
@@ -612,26 +634,36 @@ def _generate_with_tools(self) -> str:
612634
continue
613635
else:
614636
# Text response - add to history and return
637+
self._log(1, f"[Agent] Response: {content[:100]}{'...' if len(content) > 100 else ''}")
615638
self._messages.append({"role": "assistant", "content": content})
616639
return content
617640

618641
# Max tool calls reached - return a fallback error message
642+
self._log(1, f"[Agent] Max tool calls ({self.max_tool_calls}) reached")
619643
return "I apologize, but I've encountered an issue processing your request. Please try again."
620644

645+
def _get_tool_name(self, tool_call: Dict[str, Any]) -> str:
646+
"""Extract the tool name from a tool call dict (nested 'function' or flat format), defaulting to 'unknown'."""
647+
if "function" in tool_call:
648+
return tool_call["function"].get("name", "unknown")
649+
return tool_call.get("name", "unknown")
650+
621651
def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Any:
622652
"""Execute a single tool call.
623653
624654
Args:
625-
tool_call: Dict with 'name' and 'arguments' keys
655+
tool_call: Dict in OpenAI format with 'function.name' and 'function.arguments',
656+
or flat format with 'name' and 'arguments' keys.
626657
627658
Returns:
628659
Tool execution result
629660
"""
630-
name = tool_call.get("name", "")
631-
# Handle both 'arguments' (dict) and 'function' (nested dict) formats
661+
# Handle both flat format and nested 'function' format (OpenAI/ChatResponse style)
632662
if "function" in tool_call:
663+
name = tool_call["function"].get("name", "")
633664
arguments = tool_call["function"].get("arguments", {})
634665
else:
666+
name = tool_call.get("name", "")
635667
arguments = tool_call.get("arguments", {})
636668

637669
# Handle string arguments (JSON encoded)
@@ -644,12 +676,17 @@ def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Any:
644676
arguments = {}
645677

646678
if name not in self.tools:
679+
self._log(1, f"[Agent] Tool not found: {name}")
647680
return f"Error: Tool '{name}' not found"
648681

682+
self._log(2, f"[Agent] Executing {name}({arguments})")
649683
try:
650684
result = self.tools[name](**arguments)
685+
result_str = str(result)
686+
self._log(2, f"[Agent] Result: {result_str[:200]}{'...' if len(result_str) > 200 else ''}")
651687
return result
652688
except Exception as e:
689+
self._log(1, f"[Agent] Tool error: {name} - {e}")
653690
return f"Error executing tool '{name}': {str(e)}"
654691

655692
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
@@ -801,6 +838,7 @@ class DefaultAgentTau2Benchmark(Tau2Benchmark):
801838
- model_id: LLM model identifier (required)
802839
- llm_args: Optional dict of additional LLM arguments
803840
- max_tool_calls: Maximum tool calls per turn (default: 50)
841+
- verbose: Verbosity level for debugging (0=silent, 1=basic, 2=detailed)
804842
805843
Example:
806844
from maseval.benchmark.tau2 import DefaultAgentTau2Benchmark, load_tasks, configure_model_ids
@@ -809,7 +847,7 @@ class DefaultAgentTau2Benchmark(Tau2Benchmark):
809847
configure_model_ids(tasks, user_model_id="gpt-4o")
810848
811849
benchmark = DefaultAgentTau2Benchmark(
812-
agent_data={"model_id": "gpt-4o"},
850+
agent_data={"model_id": "gpt-4o", "verbose": 1},
813851
)
814852
results = benchmark.run(tasks)
815853
"""
@@ -884,6 +922,7 @@ def setup_agents(
884922
model_id = self._get_agent_model_id(agent_data)
885923
llm_args = agent_data.get("llm_args", {})
886924
max_tool_calls = agent_data.get("max_tool_calls", 50)
925+
verbose = agent_data.get("verbose", 0)
887926

888927
# Get tools and policy from environment
889928
tools = environment.create_tools()
@@ -899,6 +938,7 @@ def setup_agents(
899938
model=model,
900939
llm_args=llm_args,
901940
max_tool_calls=max_tool_calls,
941+
verbose=verbose,
902942
)
903943

904944
# Wrap in adapter

0 commit comments

Comments
 (0)