@@ -62,6 +62,7 @@ def get_model_adapter(self, model_id, **kwargs):
6262
6363from maseval import AgentAdapter , Benchmark , Evaluator , ModelAdapter , Task , User
6464from maseval .core .user import AgenticUser
65+ from maseval .core .callback import BenchmarkCallback
6566
6667from maseval .benchmark .tau2 .environment import Tau2Environment
6768from maseval .benchmark .tau2 .evaluator import Tau2Evaluator
@@ -231,41 +232,44 @@ def get_model_adapter(self, model_id, **kwargs):
231232 benchmark.run(tasks)
232233 """
233234
234- def __init__ (self , * args : Any , max_invocations : int = 50 , ** kwargs : Any ):
235+ # Maximum agent-user interaction rounds (tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps)
236+ MAX_INVOCATIONS = 50
237+
238+ def __init__ (
239+ self ,
240+ callbacks : Optional [List [BenchmarkCallback ]] = None ,
241+ n_task_repeats : int = 1 ,
242+ max_invocations : int = MAX_INVOCATIONS ,
243+ num_workers : int = 1 ,
244+ fail_on_setup_error : bool = False ,
245+ fail_on_task_error : bool = False ,
246+ fail_on_evaluation_error : bool = False ,
247+ progress_bar : bool | str = True ,
248+ ):
235249 """Initialize benchmark with tau2-specific defaults.
236250
237251 Args:
252+ callbacks: Optional list of callback handlers for monitoring execution.
253+ n_task_repeats: Number of times to repeat each task. Default 1.
238254 max_invocations: Maximum agent-user interaction rounds (default: 50).
239255 tau2-bench uses max_steps=200, where 1 turn ≈ 4 steps.
240-
241- Inherited from Benchmark (pass via kwargs):
242256 num_workers: Number of parallel task executions. Default 1 (sequential).
243- Set higher for I/O-bound workloads (e.g., LLM API calls).
244- n_task_repeats: Number of times to repeat each task. Default 1.
245- Useful for measuring variance or computing pass@k metrics.
246- callbacks: List of callback handlers for monitoring execution.
247- progress_bar: Progress display. True (default) for tqdm, "rich" for Rich,
248- or False to disable.
257+ fail_on_setup_error: If True, raise on setup errors. Default False.
249258 fail_on_task_error: If True, raise on task execution errors. Default False.
250259 fail_on_evaluation_error: If True, raise on evaluation errors. Default False.
251- fail_on_setup_error: If True, raise on setup errors. Default False.
252-
253- Example:
254- ```python
255- # Parallel execution for faster evaluation
256- benchmark = MyTau2Benchmark(num_workers=4)
257-
258- # Multiple repeats for pass@k metrics
259- benchmark = MyTau2Benchmark(n_task_repeats=4)
260-
261- # Debug mode - fail fast on errors
262- benchmark = MyTau2Benchmark(
263- fail_on_task_error=True,
264- fail_on_evaluation_error=True,
265- )
266- ```
260+ progress_bar: Progress display. True (default) for tqdm, "rich" for Rich,
261+ or False to disable.
267262 """
268- super ().__init__ (* args , max_invocations = max_invocations , ** kwargs ) # type: ignore[parameter-already-assigned]
263+ super ().__init__ (
264+ callbacks = callbacks ,
265+ n_task_repeats = n_task_repeats ,
266+ max_invocations = max_invocations ,
267+ num_workers = num_workers ,
268+ fail_on_setup_error = fail_on_setup_error ,
269+ fail_on_task_error = fail_on_task_error ,
270+ fail_on_evaluation_error = fail_on_evaluation_error ,
271+ progress_bar = progress_bar ,
272+ )
269273
270274 def _get_user_model_id (self , task : Task ) -> str :
271275 """Get user simulator model ID from task.user_data.
@@ -875,14 +879,6 @@ class DefaultAgentTau2Benchmark(Tau2Benchmark):
875879 results = benchmark.run(tasks)
876880 """
877881
878- # Cache for model adapters
879- _model_cache : Dict [str , ModelAdapter ]
880-
881- def __init__ (self , * args : Any , ** kwargs : Any ):
882- """Initialize the default agent benchmark. See Tau2Benchmark for args."""
883- super ().__init__ (* args , ** kwargs )
884- self ._model_cache = {}
885-
886882 def _get_agent_model_id (self , agent_data : Dict [str , Any ]) -> str :
887883 """Get agent model ID from agent_data.
888884
@@ -965,5 +961,9 @@ def get_model_adapter(self, model_id: str, **kwargs: Any) -> ModelAdapter:
965961
966962 Returns:
967963 ModelAdapter instance
964+
965+ Note:
966+ DefaultAgentTau2Benchmark uses lazy initialization for model caching.
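967+ Create the cache on first use (`if not hasattr(self, '_model_cache'): self._model_cache = {}`); `getattr(self, '_model_cache', {})` alone returns a new dict whenever the attribute is missing, and that dict is never stored on the instance.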
968968 """
969969 pass
0 commit comments