Skip to content

Commit 782c9bf

Browse files
committed
Fixed tau2 and agenticuser tests to work with the new ModelAdapter chat features
1 parent 91de571 commit 782c9bf

5 files changed

Lines changed: 217 additions & 197 deletions

File tree

examples/tau2_benchmark/tau2_default_agent_benchmark.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@
4949
load_tasks,
5050
)
5151

52-
# Import a ModelAdapter - using Google GenAI Tool Calling adapter
53-
# You can substitute with OpenAI, Anthropic, or any other tool-calling capable ModelAdapter
52+
# Import a ModelAdapter - using Google GenAI adapter (has built-in tool calling support)
53+
# You can substitute with OpenAI, Anthropic, LiteLLM, or any other ModelAdapter
5454
from google.genai import Client as GoogleGenAIClient
55-
from maseval.interface.inference.google_genai_tool_calling import ToolCallingGoogleGenAIAdapter
55+
from maseval.interface.inference import GoogleGenAIModelAdapter
5656

5757

5858
# =============================================================================
@@ -140,17 +140,17 @@ def __init__(
140140
super().__init__(agent_data=agent_data, **kwargs)
141141
self._model_id = model_id
142142

143-
def get_model_adapter(self, model_id: str, **kwargs: Any) -> ToolCallingGoogleGenAIAdapter:
143+
def get_model_adapter(self, model_id: str, **kwargs: Any) -> GoogleGenAIModelAdapter:
144144
"""Create a Google GenAI model adapter with tool calling support.
145145
146146
Args:
147147
model_id: Model identifier
148148
**kwargs: Additional arguments (e.g., register_name for tracing)
149149
150150
Returns:
151-
Configured ToolCallingGoogleGenAIAdapter
151+
Configured GoogleGenAIModelAdapter
152152
"""
153-
adapter = ToolCallingGoogleGenAIAdapter(get_google_client(), model_id=model_id)
153+
adapter = GoogleGenAIModelAdapter(get_google_client(), model_id=model_id)
154154

155155
# Register for tracing if requested
156156
if "register_name" in kwargs:

maseval/benchmark/tau2/tau2.py

Lines changed: 49 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -573,16 +573,16 @@ def _generate_with_tools(self) -> str:
573573
# Build messages for LLM call
574574
messages = [{"role": "system", "content": self.system_prompt}] + self._messages
575575

576-
# Generate response with tool access
577-
response = self.model.generate(
576+
# Generate response with tool access using chat() method
577+
response = self.model.chat(
578578
messages=messages,
579579
tools=self._get_tool_definitions(),
580580
**self.llm_args,
581581
)
582582

583-
# Parse response
584-
content = response.get("content", "")
585-
tool_calls = response.get("tool_calls", [])
583+
# Parse response from ChatResponse
584+
content = response.content or ""
585+
tool_calls = response.tool_calls or []
586586

587587
if tool_calls:
588588
# Add assistant message with tool calls
@@ -673,24 +673,37 @@ def _get_tool_definitions(self) -> List[Dict[str, Any]]:
673673
if param_name == "self":
674674
continue
675675

676-
# Determine parameter type
677-
param_type = "string" # Default
676+
# Determine parameter type and build property schema
677+
param_schema: Dict[str, Any] = {"description": f"Parameter: {param_name}"}
678+
678679
if param.annotation is not inspect.Parameter.empty:
679680
if param.annotation is int:
680-
param_type = "integer"
681+
param_schema["type"] = "integer"
681682
elif param.annotation is float:
682-
param_type = "number"
683+
param_schema["type"] = "number"
683684
elif param.annotation is bool:
684-
param_type = "boolean"
685+
param_schema["type"] = "boolean"
685686
elif param.annotation is list or (hasattr(param.annotation, "__origin__") and param.annotation.__origin__ is list):
686-
param_type = "array"
687+
param_schema["type"] = "array"
688+
# Add items schema for array types (required by Google GenAI)
689+
param_schema["items"] = {"type": "string"}
690+
# Try to get the inner type for List[X]
691+
if hasattr(param.annotation, "__args__") and param.annotation.__args__:
692+
inner_type = param.annotation.__args__[0]
693+
if inner_type is int:
694+
param_schema["items"] = {"type": "integer"}
695+
elif inner_type is float:
696+
param_schema["items"] = {"type": "number"}
697+
elif inner_type is bool:
698+
param_schema["items"] = {"type": "boolean"}
687699
elif param.annotation is dict:
688-
param_type = "object"
700+
param_schema["type"] = "object"
701+
else:
702+
param_schema["type"] = "string"
703+
else:
704+
param_schema["type"] = "string"
689705

690-
properties[param_name] = {
691-
"type": param_type,
692-
"description": f"Parameter: {param_name}",
693-
}
706+
properties[param_name] = param_schema
694707

695708
# Check if parameter is required (no default value)
696709
if param.default is inspect.Parameter.empty:
@@ -757,6 +770,26 @@ def get_messages(self) -> Any:
757770
"""
758771
return self._agent.get_messages()
759772

773+
def gather_traces(self) -> Dict[str, Any]:
774+
"""Gather execution traces from this agent.
775+
776+
Overrides base implementation to handle list-based message history.
777+
"""
778+
history = self.get_messages()
779+
# history is already a list, not a MessageHistory object
780+
messages = history if isinstance(history, list) else []
781+
return {
782+
"type": type(self).__name__,
783+
"gathered_at": __import__("datetime").datetime.now().isoformat(),
784+
"name": self.name,
785+
"agent_type": type(self.agent).__name__,
786+
"adapter_type": type(self).__name__,
787+
"message_count": len(messages),
788+
"messages": messages,
789+
"callbacks": [type(cb).__name__ for cb in self.callbacks],
790+
"logs": self.logs,
791+
}
792+
760793

761794
class DefaultAgentTau2Benchmark(Tau2Benchmark):
762795
"""Tau2 benchmark with default agent implementation.

0 commit comments

Comments
 (0)