Skip to content

Commit 079ef47

Browse files
committed
[Move DISCO queue to core]:
- Replace all .get() calls on required fields with explicit dict lookups.
1 parent 14bcb3f commit 079ef47

3 files changed

Lines changed: 67 additions & 28 deletions

File tree

maseval/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
UserError,
5050
UserExhaustedError,
5151
TaskTimeoutError,
52+
get_with_assert,
5253
validate_argument_type,
5354
validate_required_arguments,
5455
validate_no_extra_arguments,
@@ -106,6 +107,7 @@
106107
"ChatResponse",
107108
"ModelScorer",
108109
# Exceptions and validation
110+
"get_with_assert",
109111
"MASEvalError",
110112
"AgentError",
111113
"EnvironmentError",

maseval/benchmark/mmlu/mmlu.py

Lines changed: 27 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,14 @@ def setup_state(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
8484
8585
Args:
8686
task_data: Must contain ``"query"`` (str) and ``"environment_data"``
87-
(dict with optional ``"choices"``, ``"full_prompt"``, ``"use_full_prompt"``).
87+
(dict with ``"choices"``, ``"full_prompt"``, ``"use_full_prompt"``).
8888
"""
8989
env_data = task_data["environment_data"]
9090
return {
9191
"query": task_data["query"],
92-
"choices": env_data.get("choices", DEFAULT_CHOICES),
93-
"full_prompt": env_data.get("full_prompt", ""),
94-
"use_full_prompt": env_data.get("use_full_prompt", False),
92+
"choices": env_data["choices"],
93+
"full_prompt": env_data["full_prompt"],
94+
"use_full_prompt": env_data["use_full_prompt"],
9595
}
9696

9797
def create_tools(self) -> Dict[str, Any]:
@@ -137,7 +137,7 @@ def __init__(
137137
self.task = task
138138
self.environment = environment
139139
self.gold = task.evaluation_data["gold"]
140-
self.choices = task.environment_data.get("choices", DEFAULT_CHOICES)
140+
self.choices = task.environment_data["choices"]
141141

142142
def filter_traces(self, traces: Dict[str, Any]) -> Dict[str, Any]:
143143
"""Extract relevant traces for evaluation.
@@ -175,11 +175,11 @@ def __call__(self, traces: Dict[str, Any], final_answer: Optional[str] = None) -
175175
"predicted": predicted,
176176
"gold": self.gold,
177177
"correct": correct,
178-
"doc_id": self.task.metadata.get("doc_id"),
178+
"doc_id": self.task.metadata["doc_id"],
179179
}
180180

181181
# Extract logprobs from traces if available (for logprobs-based evaluation)
182-
messages = traces.get("messages", [])
182+
messages = traces["messages"]
183183
for msg in messages:
184184
if isinstance(msg, dict) and "logprobs" in msg:
185185
result["logprobs"] = msg["logprobs"]
@@ -445,7 +445,7 @@ def precompute_all_logprobs_lmeval(self, tasks: Sequence[Task]) -> Dict[Any, Lis
445445
instance_map = {} # (doc_id, choice_idx) -> position in results
446446

447447
for task in tasks:
448-
doc_id = task.metadata.get("doc_id")
448+
doc_id = task.metadata["doc_id"]
449449
# Get prompt from task - use full_prompt from environment_data if available
450450
if self.use_full_prompt and "full_prompt" in task.environment_data:
451451
prompt = task.environment_data["full_prompt"]
@@ -471,7 +471,7 @@ def precompute_all_logprobs_lmeval(self, tasks: Sequence[Task]) -> Dict[Any, Lis
471471
# Map results back to doc_ids
472472
doc_logprobs = {}
473473
for task in tasks:
474-
doc_id = task.metadata.get("doc_id")
474+
doc_id = task.metadata["doc_id"]
475475
logprobs = []
476476
for i in range(len(choices)):
477477
pos = instance_map[(doc_id, i)]
@@ -498,20 +498,19 @@ def run_agents(
498498
which automatically picks single-token or multi-token scoring.
499499
"""
500500
prompt = environment.get_prompt()
501-
choices = environment.state.get("choices", DEFAULT_CHOICES)
502-
doc_id = task.metadata.get("doc_id") if task else None
503-
504-
if hasattr(self, "_precomputed_logprobs") and doc_id is not None:
505-
logprobs = self._precomputed_logprobs.get(doc_id)
506-
if logprobs is not None:
507-
best_idx = logprobs.index(max(logprobs))
508-
answer = choices[best_idx]
509-
environment.state["logprobs"] = logprobs
510-
environment.state["predicted_idx"] = best_idx
511-
agent = agents[0]
512-
agent._messages.append({"role": "user", "content": prompt})
513-
agent._messages.append({"role": "assistant", "content": answer, "logprobs": logprobs})
514-
return answer
501+
choices = environment.state["choices"]
502+
doc_id = task.metadata["doc_id"]
503+
504+
if hasattr(self, "_precomputed_logprobs") and doc_id in self._precomputed_logprobs:
505+
logprobs = self._precomputed_logprobs[doc_id]
506+
best_idx = logprobs.index(max(logprobs))
507+
answer = choices[best_idx]
508+
environment.state["logprobs"] = logprobs
509+
environment.state["predicted_idx"] = best_idx
510+
agent = agents[0]
511+
agent._messages.append({"role": "user", "content": prompt})
512+
agent._messages.append({"role": "assistant", "content": answer, "logprobs": logprobs})
513+
return answer
515514

516515
logprobs = self._scorer.loglikelihood_choices(prompt, choices, delimiter=TARGET_DELIMITER)
517516

@@ -677,14 +676,14 @@ def compute_benchmark_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
677676
acc_norm_sum = 0.0
678677

679678
for res in results:
680-
if res.get("status") != STATUS_SUCCESS:
679+
if res["status"] != STATUS_SUCCESS:
681680
continue
682681

683-
evals = res.get("eval") or []
682+
evals = res["eval"] or []
684683
for entry in evals:
685-
acc_sum += entry.get("acc", 0.0)
686-
acc_norm_sum += entry.get("acc_norm", 0.0)
687-
if entry.get("correct", False):
684+
acc_sum += entry["acc"]
685+
acc_norm_sum += entry["acc_norm"]
686+
if entry["correct"]:
688687
correct_count += 1
689688

690689
return {

maseval/core/exceptions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,44 @@ def __init__(
308308
# =============================================================================
309309

310310

311+
def get_with_assert(container: Any, key: Any, error_msg: Optional[str] = None) -> Any:
    """Get a value from a container, raising ``KeyError`` if not found.

    Use instead of ``dict.get(key, default)`` when the key is **required**.
    A missing key means a bug — not a case to paper over with a fallback.

    Supports nested access via a list of keys::

        data = {"metadata": {"doc_id": 7}}
        get_with_assert(data, ["metadata", "doc_id"])
        # equivalent to: data["metadata"]["doc_id"] but with a clear error

    Args:
        container: Dictionary or other container supporting ``in`` and ``[]``.
        key: Key to look up. Pass a non-empty list for nested access; each
            element is looked up in the value returned for the previous one.
        error_msg: Custom error message. If ``None``, a descriptive default
            is generated.

    Returns:
        The value at the given key (for a list of keys, the value reached
        after following every key in order).

    Raises:
        KeyError: If the key (or any key along a nested path) is not found.
        ValueError: If ``key`` is an empty list.
    """
    if isinstance(key, list):
        # Explicit raise instead of ``assert``: assertions are removed under
        # ``python -O``, which would turn an empty path into an IndexError.
        if not key:
            raise ValueError("get_with_assert requires a non-empty list of keys for nested access")
        value = container
        # Walk the path one level at a time; each step reuses the single-key
        # branch below, so a miss at any depth raises the same KeyError.
        for part in key:
            value = get_with_assert(value, part, error_msg)
        return value

    if key not in container:
        if error_msg is None:
            error_msg = f'Required key "{key}" not in container: {container}'
        raise KeyError(error_msg)

    return container[key]
347+
348+
311349
def validate_argument_type(
312350
value: Any,
313351
expected_type: str,

0 commit comments

Comments
 (0)