Skip to content

Commit eb24ce1

Browse files
committed
fixed bugs
1 parent 6087c8f commit eb24ce1

3 files changed

Lines changed: 27 additions & 4 deletions

File tree

examples/tau2_benchmark/tau2_default_agent_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -287,7 +287,7 @@ def run_benchmark(
287287
callbacks=[logger],
288288
n_task_repeats=n_task_repeats,
289289
fail_on_setup_error=True,
290-
fail_on_task_error=False, # Continue on task errors
290+
fail_on_task_error=False, # Set to False to continue on task errors
291291
fail_on_evaluation_error=True,
292292
)
293293

maseval/benchmark/tau2/domains/retail/tools.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,8 @@ def _get_order(self, order_id: str) -> Order:
5252
"""Get the order from the database.
5353
5454
Args:
55-
order_id: The order id, such as '#W0000000'. Be careful there is a '#' symbol at the beginning.
55+
order_id: The order id, such as '#W0000000' or 'W0000000'.
56+
The '#' prefix is optional and will be added if missing.
5657
5758
Returns:
5859
The order.
@@ -62,6 +63,9 @@ def _get_order(self, order_id: str) -> Order:
6263
"""
6364
if self.db is None:
6465
raise ValueError("Database not initialized")
66+
# Normalize order_id: add '#' prefix if missing (LLMs often omit it)
67+
if not order_id.startswith("#"):
68+
order_id = f"#{order_id}"
6569
if order_id not in self.db.orders:
6670
raise ValueError("Order not found")
6771
return self.db.orders[order_id]

maseval/interface/inference/google_genai.py

Lines changed: 21 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -168,7 +168,26 @@ def _convert_messages(self, messages: List[Dict[str, Any]]) -> tuple[Optional[st
168168
if role == "system":
169169
system_instruction = content
170170
elif role == "assistant":
171-
contents.append({"role": "model", "parts": [{"text": content}]})
171+
# Handle assistant messages with or without tool calls
172+
parts = []
173+
if content:
174+
parts.append({"text": content})
175+
# Convert tool_calls to Google's function_call format
176+
tool_calls = msg.get("tool_calls", [])
177+
if tool_calls:
178+
import json
179+
180+
for tc in tool_calls:
181+
if tc.get("type") == "function":
182+
func = tc.get("function", {})
183+
args_str = func.get("arguments", "{}")
184+
try:
185+
args = json.loads(args_str) if isinstance(args_str, str) else args_str
186+
except json.JSONDecodeError:
187+
args = {}
188+
parts.append({"function_call": {"name": func.get("name", ""), "args": args}})
189+
if parts:
190+
contents.append({"role": "model", "parts": parts})
172191
elif role == "tool":
173192
# Tool response in Google format
174193
tool_call_id = msg.get("tool_call_id", "")
@@ -237,7 +256,7 @@ def _parse_response(self, response: Any) -> ChatResponse:
237256
tool_calls = None
238257
if hasattr(response, "candidates") and response.candidates:
239258
candidate = response.candidates[0]
240-
if hasattr(candidate, "content") and candidate.content:
259+
if hasattr(candidate, "content") and candidate.content and candidate.content.parts:
241260
for part in candidate.content.parts:
242261
if hasattr(part, "function_call") and part.function_call:
243262
if tool_calls is None:

0 commit comments

Comments
 (0)