@@ -84,14 +84,14 @@ def setup_state(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
8484
8585 Args:
8686 task_data: Must contain ``"query"`` (str) and ``"environment_data"``
87- (dict with optional ``"choices"``, ``"full_prompt"``, ``"use_full_prompt"``).
87+ (dict with ``"choices"``, ``"full_prompt"``, ``"use_full_prompt"``).
8888 """
8989 env_data = task_data ["environment_data" ]
9090 return {
9191 "query" : task_data ["query" ],
92- "choices" : env_data . get ( "choices" , DEFAULT_CHOICES ) ,
93- "full_prompt" : env_data . get ( "full_prompt" , "" ) ,
94- "use_full_prompt" : env_data . get ( "use_full_prompt" , False ) ,
92+ "choices" : env_data [ "choices" ] ,
93+ "full_prompt" : env_data [ "full_prompt" ] ,
94+ "use_full_prompt" : env_data [ "use_full_prompt" ] ,
9595 }
9696
9797 def create_tools (self ) -> Dict [str , Any ]:
@@ -137,7 +137,7 @@ def __init__(
137137 self .task = task
138138 self .environment = environment
139139 self .gold = task .evaluation_data ["gold" ]
140- self .choices = task .environment_data . get ( "choices" , DEFAULT_CHOICES )
140+ self .choices = task .environment_data [ "choices" ]
141141
142142 def filter_traces (self , traces : Dict [str , Any ]) -> Dict [str , Any ]:
143143 """Extract relevant traces for evaluation.
@@ -175,11 +175,11 @@ def __call__(self, traces: Dict[str, Any], final_answer: Optional[str] = None) -
175175 "predicted" : predicted ,
176176 "gold" : self .gold ,
177177 "correct" : correct ,
178- "doc_id" : self .task .metadata . get ( "doc_id" ) ,
178+ "doc_id" : self .task .metadata [ "doc_id" ] ,
179179 }
180180
181181 # Extract logprobs from traces if available (for logprobs-based evaluation)
182- messages = traces . get ( "messages" , [])
182+ messages = traces [ "messages" ]
183183 for msg in messages :
184184 if isinstance (msg , dict ) and "logprobs" in msg :
185185 result ["logprobs" ] = msg ["logprobs" ]
@@ -445,7 +445,7 @@ def precompute_all_logprobs_lmeval(self, tasks: Sequence[Task]) -> Dict[Any, Lis
445445 instance_map = {} # (doc_id, choice_idx) -> position in results
446446
447447 for task in tasks :
448- doc_id = task .metadata . get ( "doc_id" )
448+ doc_id = task .metadata [ "doc_id" ]
449449 # Get prompt from task - use full_prompt from environment_data if available
450450 if self .use_full_prompt and "full_prompt" in task .environment_data :
451451 prompt = task .environment_data ["full_prompt" ]
@@ -471,7 +471,7 @@ def precompute_all_logprobs_lmeval(self, tasks: Sequence[Task]) -> Dict[Any, Lis
471471 # Map results back to doc_ids
472472 doc_logprobs = {}
473473 for task in tasks :
474- doc_id = task .metadata . get ( "doc_id" )
474+ doc_id = task .metadata [ "doc_id" ]
475475 logprobs = []
476476 for i in range (len (choices )):
477477 pos = instance_map [(doc_id , i )]
@@ -498,20 +498,19 @@ def run_agents(
498498 which automatically picks single-token or multi-token scoring.
499499 """
500500 prompt = environment .get_prompt ()
501- choices = environment .state .get ("choices" , DEFAULT_CHOICES )
502- doc_id = task .metadata .get ("doc_id" ) if task else None
503-
504- if hasattr (self , "_precomputed_logprobs" ) and doc_id is not None :
505- logprobs = self ._precomputed_logprobs .get (doc_id )
506- if logprobs is not None :
507- best_idx = logprobs .index (max (logprobs ))
508- answer = choices [best_idx ]
509- environment .state ["logprobs" ] = logprobs
510- environment .state ["predicted_idx" ] = best_idx
511- agent = agents [0 ]
512- agent ._messages .append ({"role" : "user" , "content" : prompt })
513- agent ._messages .append ({"role" : "assistant" , "content" : answer , "logprobs" : logprobs })
514- return answer
501+ choices = environment .state ["choices" ]
502+ doc_id = task .metadata ["doc_id" ]
503+
504+ if hasattr (self , "_precomputed_logprobs" ) and doc_id in self ._precomputed_logprobs :
505+ logprobs = self ._precomputed_logprobs [doc_id ]
506+ best_idx = logprobs .index (max (logprobs ))
507+ answer = choices [best_idx ]
508+ environment .state ["logprobs" ] = logprobs
509+ environment .state ["predicted_idx" ] = best_idx
510+ agent = agents [0 ]
511+ agent ._messages .append ({"role" : "user" , "content" : prompt })
512+ agent ._messages .append ({"role" : "assistant" , "content" : answer , "logprobs" : logprobs })
513+ return answer
515514
516515 logprobs = self ._scorer .loglikelihood_choices (prompt , choices , delimiter = TARGET_DELIMITER )
517516
@@ -677,14 +676,14 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
677676 acc_norm_sum = 0.0
678677
679678 for res in results :
680- if res . get ( "status" ) != STATUS_SUCCESS :
679+ if res [ "status" ] != STATUS_SUCCESS :
681680 continue
682681
683- evals = res . get ( "eval" ) or []
682+ evals = res [ "eval" ] or []
684683 for entry in evals :
685- acc_sum += entry . get ( "acc" , 0.0 )
686- acc_norm_sum += entry . get ( "acc_norm" , 0.0 )
687- if entry . get ( "correct" , False ) :
684+ acc_sum += entry [ "acc" ]
685+ acc_norm_sum += entry [ "acc_norm" ]
686+ if entry [ "correct" ] :
688687 correct_count += 1
689688
690689 return {
0 commit comments