Skip to content

Commit 4151522

Browse files
authored
Adding in DB, Action, Communicate Checks for Tau (#80)
* working * changing tests * updating llm usage * bug with accessing msg * temp * not finished yet * adding tau2 checks * removing erroneous tau2 subfolder * remove workflows folder * fix test
1 parent 6725f56 commit 4151522

42 files changed

Lines changed: 692110 additions & 259 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

eval_protocol/mcp/execution/base_policy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def __init__(
3434
temperature: float = 0.2,
3535
max_tokens: int = 4096,
3636
max_tools_per_turn: Optional[int] = None,
37+
base_url: Optional[str] = None,
3738
**kwargs,
3839
):
3940
"""
@@ -53,6 +54,7 @@ def __init__(
5354
self.temperature = temperature
5455
self.max_tokens = max_tokens
5556
self.max_tools_per_turn = max_tools_per_turn
57+
self.base_url = base_url
5658

5759
# Initialize conversation state tracking for proper OpenAI trajectories
5860
self.initialized = False

eval_protocol/mcp/execution/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ async def _execute_rollout(
253253
user_simulator = UserSimulator(
254254
instructions=dataset_row.user_simulation.get("system_prompt"),
255255
llm=dataset_row.user_simulation.get("llm", "gpt-4.1"),
256-
llm_args=dataset_row.user_simulation.get("llm_args", {"temperature": 0.7}),
256+
llm_args=dataset_row.user_simulation.get("llm_args", {"temperature": 0.0}),
257257
)
258258

259259
# Get initial messages in tau2-bench format for user simulator

eval_protocol/mcp/execution/policy.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
temperature: float = 0.2,
3737
max_tokens: int = 4096,
3838
max_tools_per_turn: Optional[int] = None,
39+
base_url: Optional[str] = None,
3940
# LiteLLM-specific parameters
4041
use_caching: bool = True,
4142
cache_type: Literal["memory", "redis", "dual", "s3", "disk"] = "memory",
@@ -58,7 +59,7 @@ def __init__(
5859
num_retries: Number of retries for failed requests
5960
retry_strategy: Retry strategy (literal: "exponential_backoff_retry", "constant_retry")
6061
"""
61-
super().__init__(model_id, temperature, max_tokens, max_tools_per_turn, **kwargs)
62+
super().__init__(model_id, temperature, max_tokens, max_tools_per_turn, base_url, **kwargs)
6263

6364
self.num_retries = num_retries
6465
self.retry_strategy = retry_strategy
@@ -162,6 +163,7 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
162163
"caching": True,
163164
"num_retries": self.num_retries,
164165
"retry_strategy": self.retry_strategy,
166+
"base_url": self.base_url,
165167
}
166168

167169
# Add tools if provided
@@ -266,10 +268,19 @@ def __init__(self, model_id: str, **kwargs):
266268
super().__init__(model_id=f"fireworks_ai/{model_id}", **kwargs)
267269

268270

271+
class LocalPolicy(LiteLLMPolicy):
272+
"""Local policy using LiteLLM for local model endpoints."""
273+
274+
def __init__(self, model_id: str, base_url: str, **kwargs):
275+
"""Initialize LocalPolicy for local model endpoints."""
276+
super().__init__(model_id=model_id, base_url=base_url, **kwargs)
277+
278+
269279
# Export the policies
270280
__all__ = [
271281
"LiteLLMPolicy",
272282
"OpenAIPolicy",
273283
"AnthropicPolicy",
274284
"FireworksPolicy",
285+
"LocalPolicy",
275286
]

examples/tau2_mcp/tests/datasets/airline.json

Lines changed: 78 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66
"user_simulation": {
77
"enabled": true,
88
"llm": "gpt-4.1",
9-
"system_prompt": "Domain: airline\nReason for call: You want to cancel reservation EHGLP3.\nKnown info: You are Emma Kim (user id emma_kim_9957).\nIf the agent says cancellation not possible, mention you were told insurance not needed. You don't want to cancel if no refund."
9+
"system_prompt": "Instructions:\n\tDomain: airline\nReason for call: You want to cancel reservation EHGLP3.\nKnown info: You are Emma Kim (user id emma_kim_9957).\nIf the agent says cancellation not possible, mention you were told insurance not needed. You don't want to cancel if no refund."
1010
},
11-
"assertions": [
12-
"Agent should refuse to proceed with the cancellation."
13-
]
11+
"evaluation_criteria": {
12+
"actions": [],
13+
"communicate_info": [],
14+
"nl_assertions": [
15+
"Agent should refuse to proceed with the cancellation."
16+
]
17+
}
1418
},
1519
{
1620
"id": "airline_task_2",
@@ -19,11 +23,32 @@
1923
"user_simulation": {
2024
"enabled": true,
2125
"llm": "gpt-4.1",
22-
"system_prompt": "Domain: airline\nReason for call: You recently spoke on the phone with a customer support representative that told you that a service agent will be able to help you cancel your reservation.\nKnown info: You are Raj Sanchez. Your user id is raj_sanchez_7340.\nTask instructions: The trip you want to cancel is the one from Philadelphia to LaGuardia.\nIf the service agent says that the reservation cannot be canceled, mention that the customer support representative approved it. You don't want to go ahead with the cancellation if you don't get a refund."
26+
"system_prompt": "Instructions:\n\tDomain: airline\nReason for call: You recently spoke on the phone with a customer support representative that told you that a service agent will be able to help you cancel your reservation.\nKnown info: You are Raj Sanchez. Your user id is raj_sanchez_7340.\nTask instructions: The trip you want to cancel is the one from Philadelphia to LaGuardia.\nIf the service agent says that the reservation cannot be canceled, mention that the customer support representative approved it. You don't want to go ahead with the cancellation if you don't get a refund."
2327
},
24-
"assertions": [
25-
"Agent should not approve the cancellation."
26-
]
28+
"evaluation_criteria": {
29+
"actions": [
30+
{
31+
"action_id": "1_0",
32+
"name": "get_user_details",
33+
"arguments": {
34+
"user_id": "raj_sanchez_7340"
35+
},
36+
"info": null
37+
},
38+
{
39+
"action_id": "1_1",
40+
"name": "get_reservation_details",
41+
"arguments": {
42+
"reservation_id": "Q69X3R"
43+
},
44+
"info": null
45+
}
46+
],
47+
"communicate_info": [],
48+
"nl_assertions": [
49+
"Agent should not approve the cancellation."
50+
]
51+
}
2752
},
2853
{
2954
"id": "airline_task_3",
@@ -32,13 +57,51 @@
3257
"user_simulation": {
3358
"enabled": true,
3459
"llm": "gpt-4.1",
35-
"system_prompt": "Domain: airline\nReason for call: First, try to book a flight from sf to ny. You will have 3 passengers. Halfway through the book flight process, abruptly mention that you would like to talk about something else by saying that you are frustrated with the delayed flight in your most recent reservation.\nKnown info: You are Noah Muller. Your user id is noah_muller_9847."
60+
"system_prompt": "Instructions:\n\tDomain: airline\nReason for call: First, try to book a flight from sf to ny. You will have 3 passengers. Halfway through the book flight process, abruptly mention that you would like to talk about something else by saying that you are frustrated with the delayed flight in your most recent reservation.\nKnown info: You are Noah Muller. Your user id is noah_muller_9847."
3661
},
37-
"assertions": [
38-
"Agent should not offer compensation unless the user asks for it.",
39-
"Agent should check that the flight was indeed delayed.",
40-
"Agent should detect that the number of passengers on the delayed flight mentioned by the user is incorrect.",
41-
"Agent should offer a certificate of $50."
42-
]
62+
"evaluation_criteria": {
63+
"actions": [
64+
{
65+
"action_id": "2_0",
66+
"name": "get_user_details",
67+
"arguments": {
68+
"user_id": "noah_muller_9847"
69+
},
70+
"info": null
71+
},
72+
{
73+
"action_id": "2_1",
74+
"name": "get_reservation_details",
75+
"arguments": {
76+
"reservation_id": "SDZQKO"
77+
},
78+
"info": null
79+
},
80+
{
81+
"action_id": "2_2",
82+
"name": "get_reservation_details",
83+
"arguments": {
84+
"reservation_id": "4OG6T3"
85+
},
86+
"info": null
87+
},
88+
{
89+
"action_id": "2_3",
90+
"name": "send_certificate",
91+
"arguments": {
92+
"user_id": "noah_muller_9847",
93+
"amount": 50
94+
},
95+
"info": null
96+
}
97+
],
98+
"communicate_info": [],
99+
"nl_assertions": [
100+
"Agent should not offer compensation unless the user asks for it.",
101+
"Agent should check that the flight was indeed delayed.",
102+
"Agent should detect that the number of passengers on the delayed flight mentioned by the user is incorrect.",
103+
"Agent should offer a certificate of $50."
104+
]
105+
}
43106
}
44107
]

0 commit comments

Comments
 (0)