Skip to content

Commit c7fef33

Browse files
author
Dylan Huang
committed
savev
1 parent 057e132 commit c7fef33

4 files changed

Lines changed: 160 additions & 0 deletions

File tree

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from collections.abc import Sequence
2+
3+
from eval_protocol.data_loader.models import (
4+
DataLoaderContext,
5+
DataLoaderResult,
6+
DataLoaderVariant,
7+
EvaluationDataLoader,
8+
)
9+
from eval_protocol.models import EvaluationRow, Message
10+
from eval_protocol.pytest.types import InputMessagesParam
11+
12+
13+
class InlineDataLoader(EvaluationDataLoader):
14+
"""Data loader for inline ``EvaluationRow`` or message payloads."""
15+
16+
rows: Sequence[EvaluationRow] | None = None
17+
messages: Sequence[InputMessagesParam] | None = None
18+
variant_id: str = "inline"
19+
description: str | None = None
20+
21+
def __post_init__(self) -> None:
22+
if self.rows is None and self.messages is None:
23+
raise ValueError("InlineDataLoader requires rows or messages to be provided")
24+
25+
def variants(self) -> Sequence[DataLoaderVariant]:
26+
def _load(ctx: DataLoaderContext) -> DataLoaderResult:
27+
resolved_rows: list[EvaluationRow] = []
28+
if self.rows is not None:
29+
resolved_rows.extend(row.model_copy(deep=True) for row in self.rows)
30+
if self.messages is not None:
31+
for dataset_messages in self.messages:
32+
row_messages: list[Message] = []
33+
for msg in dataset_messages:
34+
if isinstance(msg, Message):
35+
row_messages.append(msg.model_copy(deep=True))
36+
else:
37+
row_messages.append(Message.model_validate(msg))
38+
resolved_rows.append(EvaluationRow(messages=row_messages))
39+
40+
if ctx.max_rows is not None:
41+
resolved_rows = resolved_rows[: ctx.max_rows]
42+
43+
metadata = {
44+
"data_loader_variant_id": self.variant_id,
45+
"data_loader_type": "inline",
46+
"row_count": len(resolved_rows),
47+
}
48+
49+
return DataLoaderResult(
50+
rows=resolved_rows,
51+
source_id=self.variant_id,
52+
source_metadata=metadata,
53+
)
54+
55+
description = self.description or self.variant_id
56+
return [
57+
DataLoaderVariant(
58+
id=self.variant_id,
59+
description=description,
60+
loader=_load,
61+
metadata={"type": "inline"},
62+
)
63+
]
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
"""Data loader abstractions"""
2+
3+
from __future__ import annotations
4+
5+
from collections.abc import Sequence
6+
from typing import Any, Callable
7+
from typing_extensions import Protocol
8+
9+
from pydantic import BaseModel, Field
10+
11+
from eval_protocol.models import EvaluationRow
12+
from eval_protocol.pytest.types import EvaluationTestMode
13+
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
14+
15+
16+
class DataLoaderContext(BaseModel):
17+
"""Context provided to loader variants when materializing data."""
18+
19+
max_rows: int | None = Field(default=None, ge=1, description="Maximum number of rows to load")
20+
preprocess_fn: Callable[[list[EvaluationRow]], list[EvaluationRow]] | None = Field(
21+
default=None, description="Optional preprocessing function for evaluation rows"
22+
)
23+
logger: DatasetLogger = Field(description="Dataset logger for tracking operations")
24+
invocation_id: str = Field(description="Unique identifier for this invocation")
25+
experiment_id: str = Field(description="Unique identifier for this experiment")
26+
mode: EvaluationTestMode = Field(description="The evaluation test mode")
27+
28+
class Config:
29+
arbitrary_types_allowed = True # For Callable and DatasetLogger types
30+
31+
32+
class DataLoaderResult(BaseModel):
33+
"""Rows and metadata returned by a loader variant."""
34+
35+
rows: list[EvaluationRow] = Field(description="List of evaluation rows loaded")
36+
source_id: str = Field(description="Unique identifier for the data source")
37+
source_metadata: dict[str, Any] = Field(
38+
default_factory=dict, description="Additional metadata about the data source"
39+
)
40+
raw_payload: Any | None = Field(default=None, description="Raw payload data if available")
41+
preprocessed: bool = Field(default=False, description="Whether the data has been preprocessed")
42+
43+
class Config:
44+
arbitrary_types_allowed = True # For Any type in raw_payload
45+
46+
47+
class DataLoaderVariant(BaseModel):
48+
"""Single parameterizable variant from a data loader."""
49+
50+
id: str = Field(description="Unique identifier for this variant")
51+
description: str = Field(description="Human-readable description of this variant")
52+
loader: Callable[[DataLoaderContext], DataLoaderResult] = Field(
53+
description="Function that loads data for this variant"
54+
)
55+
metadata: dict[str, Any] = Field(default_factory=dict, description="Additional metadata for this variant")
56+
57+
class Config:
58+
arbitrary_types_allowed = True # For Callable type
59+
60+
def load(self, ctx: DataLoaderContext) -> DataLoaderResult:
61+
"""Load a dataset for this variant using the provided context."""
62+
63+
return self.loader(ctx)
64+
65+
66+
class EvaluationDataLoader(Protocol):
67+
"""Protocol for data loaders that can be consumed by ``evaluation_test``."""
68+
69+
def variants(self) -> Sequence[DataLoaderVariant]:
70+
"""Return parameterizable variants emitted by this loader."""
71+
...

eval_protocol/pytest/evaluation_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212
from tqdm import tqdm
1313

14+
from eval_protocol.data_loader.models import EvaluationDataLoader
1415
from eval_protocol.dataset_logger import default_logger
1516
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
1617
from eval_protocol.human_id import generate_id, num_combinations
@@ -69,6 +70,7 @@ def evaluation_test(
6970
input_messages: Sequence[list[InputMessagesParam] | None] | None = None,
7071
input_dataset: Sequence[DatasetPathParam] | None = None,
7172
input_rows: Sequence[list[EvaluationRow]] | None = None,
73+
input_data_loaders: Sequence[EvaluationDataLoader] | EvaluationDataLoader | None = None,
7274
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter, # pyright: ignore[reportExplicitAny]
7375
rollout_processor: RolloutProcessor | None = None,
7476
evaluation_test_kwargs: Sequence[EvaluationInputParam | None] | None = None,
@@ -131,6 +133,7 @@ def evaluation_test(
131133
input_rows: Pre-constructed EvaluationRow objects to use directly. This is useful
132134
when you want to provide EvaluationRow objects with custom metadata, input_messages,
133135
or other fields already populated. Will be passed as "input_dataset" to the test function.
136+
input_loaders: Data loaders to use to load the input dataset.
134137
dataset_adapter: Function to convert the input dataset to a list of
135138
EvaluationRows. This is useful if you have a custom dataset format.
136139
completion_params: Generation parameters for the rollout.
@@ -171,6 +174,11 @@ def evaluation_test(
171174

172175
active_logger: DatasetLogger = logger if logger else default_logger
173176

177+
if input_data_loaders is not None and (
178+
input_dataset is not None or input_messages is not None or input_rows is not None
179+
):
180+
raise ValueError("data_loaders cannot be combined with input_dataset, input_messages, or input_rows.")
181+
174182
# Optional global overrides via environment for ad-hoc experimentation
175183
# EP_INPUT_PARAMS_JSON can contain a JSON object that will be deep-merged
176184
# into input_params (e.g., '{"temperature":0,"extra_body":{"reasoning":{"effort":"low"}}}').
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from eval_protocol.data_loader.inline_data_loader import InlineDataLoader
2+
from eval_protocol.models import EvaluationRow, Message
3+
from eval_protocol.pytest import evaluation_test
4+
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
5+
6+
7+
@evaluation_test(
8+
data_loaders=InlineDataLoader(
9+
messages=[[Message(role="user", content="What is 2 + 2?")]],
10+
),
11+
)
12+
def test_inline_data_loader(row: EvaluationRow) -> EvaluationRow:
13+
"""Inline data loader should feed pre-constructed message bundles."""
14+
15+
assert row.messages[0].content == "What is 2 + 2?"
16+
assert row.input_metadata.dataset_info is not None
17+
assert row.input_metadata.dataset_info.get("data_loader_variant_id") == "inline"
18+
return row

0 commit comments

Comments
 (0)