@@ -74,6 +74,7 @@ def evaluation_test(
7474 model : List [ModelParam ],
7575 input_messages : Optional [List [InputMessagesParam ]] = None ,
7676 input_dataset : Optional [List [DatasetPathParam ]] = None ,
77+ dataset_adapter : Optional [Callable [[List [Dict [str , Any ]]], Dataset ]] = lambda x : x ,
7778 input_params : Optional [List [InputParam ]] = None ,
7879 rollout_processor : Callable [
7980 [EvaluationRow , ModelParam , InputParam ], List [EvaluationRow ]
@@ -90,8 +91,13 @@ def evaluation_test(
9091
9192 Args:
9293 model: Model identifiers to query.
93- input_messages: Messages to send to the model.
94- input_dataset: Paths to JSONL datasets.
94+ input_messages: Messages to send to the model. This is useful if you
95+ don't have a dataset but can hard-code the messages.
96+ input_dataset: Paths to JSONL datasets. This is useful if you have a
97+ dataset already. Provide a dataset_adapter to convert the input dataset
98+ to a list of EvaluationRows if you have a custom dataset format.
99+ dataset_adapter: Function to convert the input dataset to a list of
100+ EvaluationRows. This is useful if you have a custom dataset format.
95101 input_params: Generation parameters for the model.
96102 rollout_processor: Function used to perform the rollout.
97103 aggregation_method: How to aggregate scores across rows.
@@ -240,16 +246,9 @@ def wrapper_body(**kwargs):
240246 data = load_jsonl (kwargs ["dataset_path" ])
241247 if max_dataset_rows is not None :
242248 data = data [:max_dataset_rows ]
243- input_dataset = []
244- for entry in data :
245- user_query = entry .get ("user_query" ) or entry .get ("prompt" )
246- if not user_query :
247- continue
248- messages = [Message (role = "user" , content = user_query )]
249- row = EvaluationRow (
250- messages = messages ,
251- ground_truth = entry .get ("ground_truth_for_eval" ),
252- )
249+ data = dataset_adapter (data )
250+ input_dataset : List [EvaluationRow ] = []
251+ for row in data :
253252 processed = rollout_processor (
254253 row , model = model_name , input_params = kwargs .get ("input_params" ) or {}
255254 )
0 commit comments