Skip to content

Commit fc66d36

Browse files
committed
fixed bug in tau2 implementation
1 parent c092da1 commit fc66d36

3 files changed

Lines changed: 49 additions & 46 deletions

File tree

maseval/benchmark/tau2/tau2.py

Lines changed: 38 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -736,71 +736,64 @@ def _execute_tool_call(self, tool_call: Dict[str, Any]) -> Any:
736736
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
737737
"""Generate tool definitions for the LLM.
738738
739+
Uses docstring_parser and Pydantic create_model to build parameter
740+
schemas, matching the original tau2-bench Tool.openai_schema approach.
741+
739742
Returns:
740743
List of tool definitions in OpenAI function calling format
741744
"""
742745
import inspect
746+
from typing import Any as TypingAny
747+
748+
from docstring_parser import parse as parse_docstring
749+
from pydantic import Field, create_model
743750

744751
definitions = []
745752
for name, func in self.tools.items():
746753
sig = inspect.signature(func)
747-
doc = func.__doc__ or f"Tool: {name}"
754+
doc = parse_docstring(func.__doc__ or "")
748755

749-
# Build parameters schema
750-
properties = {}
751-
required = []
756+
# Build tool description from parsed docstring (short + long)
757+
if doc.short_description:
758+
description = doc.short_description
759+
if doc.long_description:
760+
description += "\n\n" + doc.long_description
761+
else:
762+
description = name
763+
764+
# Build Pydantic model from signature + docstring params
765+
doc_params = {p.arg_name: p for p in doc.params}
766+
model_fields = {}
752767

753768
for param_name, param in sig.parameters.items():
754769
if param_name == "self":
755770
continue
756771

757-
# Determine parameter type and build property schema
758-
param_schema: Dict[str, Any] = {"description": f"Parameter: {param_name}"}
759-
760-
if param.annotation is not inspect.Parameter.empty:
761-
if param.annotation is int:
762-
param_schema["type"] = "integer"
763-
elif param.annotation is float:
764-
param_schema["type"] = "number"
765-
elif param.annotation is bool:
766-
param_schema["type"] = "boolean"
767-
elif param.annotation is list or (hasattr(param.annotation, "__origin__") and param.annotation.__origin__ is list):
768-
param_schema["type"] = "array"
769-
# Add items schema for array types (required by Google GenAI)
770-
param_schema["items"] = {"type": "string"}
771-
# Try to get the inner type for List[X]
772-
if hasattr(param.annotation, "__args__") and param.annotation.__args__:
773-
inner_type = param.annotation.__args__[0]
774-
if inner_type is int:
775-
param_schema["items"] = {"type": "integer"}
776-
elif inner_type is float:
777-
param_schema["items"] = {"type": "number"}
778-
elif inner_type is bool:
779-
param_schema["items"] = {"type": "boolean"}
780-
elif param.annotation is dict:
781-
param_schema["type"] = "object"
782-
else:
783-
param_schema["type"] = "string"
784-
else:
785-
param_schema["type"] = "string"
786-
787-
properties[param_name] = param_schema
788-
789-
# Check if parameter is required (no default value)
790-
if param.default is inspect.Parameter.empty:
791-
required.append(param_name)
772+
anno = param.annotation
773+
default = param.default
774+
775+
if default is param.empty:
776+
default = ... # required
777+
778+
if param_name in doc_params:
779+
default = Field(default, description=doc_params[param_name].description)
780+
if (anno is param.empty) and (doc_params[param_name].type_name is not None):
781+
anno = doc_params[param_name].type_name
782+
783+
if anno is param.empty:
784+
anno = TypingAny
785+
786+
model_fields[param_name] = (anno, default)
787+
788+
params_model = create_model("parameters", **model_fields) # type: ignore[call-overload]
792789

793790
definitions.append(
794791
{
795792
"type": "function",
796793
"function": {
797794
"name": name,
798-
"description": doc.strip().split("\n")[0], # First line of docstring
799-
"parameters": {
800-
"type": "object",
801-
"properties": properties,
802-
"required": required,
803-
},
795+
"description": description,
796+
"parameters": params_model.model_json_schema(),
804797
},
805798
}
806799
)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ multiagentbench = [
8080
# dependency of keybert. Lower versions are incompatible with huggingfacehub
8181
"sentence-transformers>=2.3.0",
8282
]
83-
tau2 = []
83+
tau2 = ["docstring-parser>=0.16"]
8484

8585
# Dependencies for running examples (only what's actually used)
8686
examples = [

uv.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)