Skip to content

Commit d49f1bb

Browse files
committed
fixed tau documentation
1 parent b1ea248 commit d49f1bb

3 files changed

Lines changed: 44 additions & 38 deletions

File tree

maseval/benchmark/tau2/data_loader.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,7 @@ def load_tasks(
327327
# Convert to MASEval Task objects
328328
tasks = []
329329
for raw_task in raw_tasks:
330-
task = _convert_tau2_task_to_maseval(
331-
raw_task, domain, split, domain_config, timeout_seconds, max_retries
332-
)
330+
task = _convert_tau2_task_to_maseval(raw_task, domain, split, domain_config, timeout_seconds, max_retries)
333331
tasks.append(task)
334332

335333
return TaskQueue(tasks)

maseval/benchmark/tau2/tau2.py

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def get_model_adapter(self, model_id, **kwargs):
6262

6363
from maseval import AgentAdapter, Benchmark, Evaluator, ModelAdapter, Task, User
6464
from maseval.core.user import AgenticUser
65+
from maseval.core.callback import BenchmarkCallback
6566

6667
from maseval.benchmark.tau2.environment import Tau2Environment
6768
from maseval.benchmark.tau2.evaluator import Tau2Evaluator
@@ -231,41 +232,44 @@ def get_model_adapter(self, model_id, **kwargs):
231232
benchmark.run(tasks)
232233
"""
233234

234-
def __init__(self, *args: Any, max_invocations: int = 50, **kwargs: Any):
235+
# Maximum agent-user interaction rounds (tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps)
236+
MAX_INVOCATIONS = 50
237+
238+
def __init__(
239+
self,
240+
callbacks: Optional[List[BenchmarkCallback]] = None,
241+
n_task_repeats: int = 1,
242+
max_invocations: int = MAX_INVOCATIONS,
243+
num_workers: int = 1,
244+
fail_on_setup_error: bool = False,
245+
fail_on_task_error: bool = False,
246+
fail_on_evaluation_error: bool = False,
247+
progress_bar: bool | str = True,
248+
):
235249
"""Initialize benchmark with tau2-specific defaults.
236250
237251
Args:
252+
callbacks: Optional list of callback handlers for monitoring execution.
253+
n_task_repeats: Number of times to repeat each task. Default 1.
238254
max_invocations: Maximum agent-user interaction rounds (default: 50).
239255
tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps.
240-
241-
Inherited from Benchmark (pass via kwargs):
242256
num_workers: Number of parallel task executions. Default 1 (sequential).
243-
Set higher for I/O-bound workloads (e.g., LLM API calls).
244-
n_task_repeats: Number of times to repeat each task. Default 1.
245-
Useful for measuring variance or computing pass@k metrics.
246-
callbacks: List of callback handlers for monitoring execution.
247-
progress_bar: Progress display. True (default) for tqdm, "rich" for Rich,
248-
or False to disable.
257+
fail_on_setup_error: If True, raise on setup errors. Default False.
249258
fail_on_task_error: If True, raise on task execution errors. Default False.
250259
fail_on_evaluation_error: If True, raise on evaluation errors. Default False.
251-
fail_on_setup_error: If True, raise on setup errors. Default False.
252-
253-
Example:
254-
```python
255-
# Parallel execution for faster evaluation
256-
benchmark = MyTau2Benchmark(num_workers=4)
257-
258-
# Multiple repeats for pass@k metrics
259-
benchmark = MyTau2Benchmark(n_task_repeats=4)
260-
261-
# Debug mode - fail fast on errors
262-
benchmark = MyTau2Benchmark(
263-
fail_on_task_error=True,
264-
fail_on_evaluation_error=True,
265-
)
266-
```
260+
progress_bar: Progress display. True (default) for tqdm, "rich" for Rich,
261+
or False to disable.
267262
"""
268-
super().__init__(*args, max_invocations=max_invocations, **kwargs) # type: ignore[parameter-already-assigned]
263+
super().__init__(
264+
callbacks=callbacks,
265+
n_task_repeats=n_task_repeats,
266+
max_invocations=max_invocations,
267+
num_workers=num_workers,
268+
fail_on_setup_error=fail_on_setup_error,
269+
fail_on_task_error=fail_on_task_error,
270+
fail_on_evaluation_error=fail_on_evaluation_error,
271+
progress_bar=progress_bar,
272+
)
269273

270274
def _get_user_model_id(self, task: Task) -> str:
271275
"""Get user simulator model ID from task.user_data.
@@ -875,14 +879,6 @@ class DefaultAgentTau2Benchmark(Tau2Benchmark):
875879
results = benchmark.run(tasks)
876880
"""
877881

878-
# Cache for model adapters
879-
_model_cache: Dict[str, ModelAdapter]
880-
881-
def __init__(self, *args: Any, **kwargs: Any):
882-
"""Initialize the default agent benchmark. See Tau2Benchmark for args."""
883-
super().__init__(*args, **kwargs)
884-
self._model_cache = {}
885-
886882
def _get_agent_model_id(self, agent_data: Dict[str, Any]) -> str:
887883
"""Get agent model ID from agent_data.
888884
@@ -965,5 +961,9 @@ def get_model_adapter(self, model_id: str, **kwargs: Any) -> ModelAdapter:
965961
966962
Returns:
967963
ModelAdapter instance
964+
965+
Note:
966+
DefaultAgentTau2Benchmark uses lazy initialization for model caching.
967+
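Subclass implementations should read the cache via `getattr(self, '_model_cache', {})`, because `__init__` no longer creates the attribute; the `{}` fallback is a fresh dict, so assign it back to `self._model_cache` if cached adapters should persist between calls.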
968968
"""
969969
pass

tests/test_benchmarks/test_tau2/test_default_agent.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,8 @@ def test_init_basic(self):
503503
"""Test basic initialization."""
504504
benchmark = DummyDefaultAgentBenchmark()
505505

506-
assert benchmark._model_cache == {}
506+
# Benchmark should be initialized successfully
507+
assert benchmark.max_invocations == 50 # Tau2 default
507508

508509
def test_init_with_all_options(self):
509510
"""Test initialization with all options."""
@@ -515,6 +516,13 @@ def test_init_with_all_options(self):
515516
assert benchmark.n_task_repeats == 3
516517
assert benchmark.max_invocations == 5
517518

519+
def test_default_max_invocations(self):
520+
"""Test that default max_invocations is 50 from class attribute."""
521+
benchmark = DummyDefaultAgentBenchmark()
522+
523+
assert benchmark.max_invocations == 50
524+
assert benchmark.MAX_INVOCATIONS == 50
525+
518526

519527
@pytest.mark.benchmark
520528
class TestDefaultAgentTau2BenchmarkSetupAgents:

0 commit comments

Comments
 (0)