Skip to content

Commit acffbbb

Browse files
committed
Added task fields to report in Benchmark.
1 parent 270c7e2 commit acffbbb

6 files changed

Lines changed: 1536 additions & 138 deletions

File tree

maseval/core/benchmark.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,11 @@ def _execute_task_repetition(
12351235
"traces": execution_traces,
12361236
"config": execution_configs,
12371237
"eval": eval_results,
1238+
"task": {
1239+
"query": task.query,
1240+
"metadata": dict(task.metadata),
1241+
"protocol": task.protocol.to_dict(),
1242+
},
12381243
}
12391244

12401245
# Clear registry after task repetition completes

maseval/core/callbacks/result_logger.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def __init__(
6262
include_traces: bool = True,
6363
include_config: bool = True,
6464
include_eval: bool = True,
65+
include_task: bool = False,
6566
validate_on_completion: bool = True,
6667
):
6768
"""Initialize the result logger.
@@ -70,12 +71,15 @@ def __init__(
7071
include_traces: If True, include execution traces in logged results
7172
include_config: If True, include configuration in logged results
7273
include_eval: If True, include evaluation results in logged results
74+
include_task: If True, include task data (query, metadata, protocol)
75+
in logged results
7376
validate_on_completion: If True, validate all iterations were logged at end
7477
"""
7578
super().__init__()
7679
self.include_traces = include_traces
7780
self.include_config = include_config
7881
self.include_eval = include_eval
82+
self.include_task = include_task
7983
self.validate_on_completion = validate_on_completion
8084

8185
# Tracking for validation
@@ -173,6 +177,9 @@ def _filter_report(self, report: Dict) -> Dict:
173177
if self.include_eval and "eval" in report:
174178
filtered["eval"] = report["eval"]
175179

180+
if self.include_task and "task" in report:
181+
filtered["task"] = report["task"]
182+
176183
return filtered
177184

178185
def _report_validation_errors(self) -> None:
@@ -306,6 +313,7 @@ def __init__(
306313
include_traces: bool = True,
307314
include_config: bool = True,
308315
include_eval: bool = True,
316+
include_task: bool = False,
309317
validate_on_completion: bool = True,
310318
):
311319
"""Initialize the file logger.
@@ -322,12 +330,15 @@ def __init__(
322330
include_traces: If True, include execution traces in logged results
323331
include_config: If True, include configuration in logged results
324332
include_eval: If True, include evaluation results in logged results
333+
include_task: If True, include task data (query, metadata, protocol)
334+
in logged results
325335
validate_on_completion: If True, validate all iterations were logged
326336
"""
327337
super().__init__(
328338
include_traces=include_traces,
329339
include_config=include_config,
330340
include_eval=include_eval,
341+
include_task=include_task,
331342
validate_on_completion=validate_on_completion,
332343
)
333344

@@ -518,6 +529,7 @@ def _write_metadata(self) -> None:
518529
"include_traces": self.include_traces,
519530
"include_config": self.include_config,
520531
"include_eval": self.include_eval,
532+
"include_task": self.include_task,
521533
"validation_enabled": self.validate_on_completion,
522534
}
523535

maseval/core/task.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,20 @@ class TaskProtocol:
5252
priority: int = 0
5353
tags: Dict[str, Any] = field(default_factory=dict)
5454

55+
def to_dict(self) -> Dict[str, Any]:
    """Convert to a JSON-serializable dictionary.

    The result is a detached snapshot: ``tags`` is rebuilt from plain
    ``dict``/``list`` containers at every nesting level, so mutating the
    returned structure never reaches back into this protocol (and vice
    versa). Leaf values inside ``tags`` are passed through untouched, so
    the result is only JSON-serializable if those leaves are.

    Returns:
        Dictionary with all fields. Enum values are converted to strings.
    """

    def _detach(value: Any) -> Any:
        # Rebuild mutable containers instead of sharing them. Using plain
        # dict/list (rather than copy.deepcopy) also normalizes dict
        # subclasses that reject item assignment while being reconstructed.
        if isinstance(value, dict):
            return {key: _detach(item) for key, item in value.items()}
        if isinstance(value, list):
            return [_detach(item) for item in value]
        return value

    return {
        "timeout_seconds": self.timeout_seconds,
        "timeout_action": self.timeout_action.value,
        "max_retries": self.max_retries,
        "priority": self.priority,
        "tags": _detach(self.tags),
    }
68+
5569

5670
class FrozenDict(dict):
5771
"""A dict subclass that raises ``TaskFrozenError`` on any mutation attempt.

tests/test_core/test_callbacks/test_result_logger.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,43 @@ def test_filter_report_status_and_error_absent(self):
174174
assert filtered["status"] is None
175175
assert filtered["error"] is None
176176

177+
def test_filter_report_includes_task_when_enabled(self):
    """Test that task data is included in filtered report when include_task is True.

    All three parts of the task entry (query, metadata, protocol) are
    checked, so a filter that silently dropped one of them would fail here.
    """
    logger = MockResultLogger(include_task=True)

    task_payload = {
        "query": "What is 2+2?",
        "metadata": {"difficulty": "easy"},
        "protocol": {"timeout_seconds": None, "timeout_action": "skip", "max_retries": 0, "priority": 0, "tags": {}},
    }
    report = {
        "task_id": "task_0",
        "repeat_idx": 0,
        "traces": {},
        "config": {},
        "eval": {},
        "task": task_payload,
    }

    filtered = logger._filter_report(report)

    assert "task" in filtered
    assert filtered["task"]["query"] == "What is 2+2?"
    assert filtered["task"]["metadata"] == {"difficulty": "easy"}
    # The protocol sub-dict was built but never checked before.
    assert filtered["task"]["protocol"] == task_payload["protocol"]
199+
200+
def test_filter_report_excludes_task_by_default(self):
    """A logger created with default options leaves ``task`` out of the filtered report."""
    task_entry = {"query": "What is 2+2?", "metadata": {}, "protocol": {}}
    report = {"task_id": "task_0", "repeat_idx": 0, "task": task_entry}

    default_logger = MockResultLogger()
    filtered = default_logger._filter_report(report)

    assert "task" not in filtered
213+
177214
def test_filter_report_partial_included(self):
178215
"""Test report filtering with only some fields included."""
179216
logger = MockResultLogger(include_traces=False, include_config=True, include_eval=False)

tests/test_core/test_task_protocol.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,46 @@ def test_tags_isolation(self):
6767

6868
assert "key" not in p2.tags
6969

70+
def test_to_dict_defaults(self):
    """A default-constructed protocol serializes to its default field values."""
    expected = {
        "timeout_seconds": None,
        "timeout_action": "skip",
        "max_retries": 0,
        "priority": 0,
        "tags": {},
    }

    serialized = TaskProtocol().to_dict()

    assert serialized == expected
82+
83+
def test_to_dict_custom_values(self):
    """Explicit field values are serialized as given; the enum comes out as its string value."""
    expected = {
        "timeout_seconds": 60.0,
        "timeout_action": "retry",
        "max_retries": 3,
        "priority": 10,
        "tags": {"category": "hard"},
    }

    protocol = TaskProtocol(
        timeout_seconds=60.0,
        timeout_action=TimeoutAction.RETRY,
        max_retries=3,
        priority=10,
        tags={"category": "hard"},
    )

    assert protocol.to_dict() == expected
101+
102+
def test_to_dict_returns_new_dict(self):
    """Editing ``tags`` on the serialized copy must leave the protocol's own tags alone."""
    protocol = TaskProtocol(tags={"key": "value"})

    snapshot = protocol.to_dict()
    exported_tags = snapshot["tags"]
    exported_tags["key"] = "modified"

    assert protocol.tags["key"] == "value"
109+
70110

71111
@pytest.mark.core
72112
class TestTaskWithProtocol:

0 commit comments

Comments
 (0)