|
4 | 4 |
|
5 | 5 | from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key |
6 | 6 | from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message |
7 | | -from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig |
| 7 | +from eval_protocol.pytest.types import Dataset, ModelParam, RolloutProcessorConfig |
8 | 8 |
|
9 | 9 |
|
10 | | -def default_single_turn_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]: |
| 10 | +def default_single_turn_rollout_processor( |
| 11 | + rows: List[EvaluationRow], config: RolloutProcessorConfig |
| 12 | +) -> List[EvaluationRow]: |
11 | 13 | """Generate a single response from a Fireworks model.""" |
12 | 14 |
|
13 | 15 | api_key = get_fireworks_api_key() |
14 | 16 | api_base = get_fireworks_api_base() |
15 | 17 | client = OpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1") |
16 | 18 |
|
17 | | - if len(row.messages) == 0: |
18 | | - raise ValueError("Messages is empty. Please provide a non-empty dataset") |
| 19 | + dataset: Dataset = [] |
| 20 | + for row in rows: |
| 21 | + if len(row.messages) == 0: |
| 22 | + raise ValueError("Messages is empty. Please provide a non-empty dataset") |
19 | 23 |
|
20 | | - messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] |
| 24 | + messages_payload = [{"role": m.role, "content": m.content} for m in row.messages] |
21 | 25 |
|
22 | | - response = client.chat.completions.create(model=config.model, messages=messages_payload, **config.input_params) |
23 | | - assistant_content = response.choices[0].message.content or "" |
24 | | - messages = list(row.messages) + [Message(role="assistant", content=assistant_content)] |
25 | | - processed = EvaluationRow( |
26 | | - messages=messages, |
27 | | - ground_truth=row.ground_truth, |
28 | | - input_metadata=InputMetadata(completion_params=CompletionParams(model=config.model)), |
29 | | - ) |
30 | | - return [processed] |
| 26 | + response = client.chat.completions.create(model=config.model, messages=messages_payload, **config.input_params) |
| 27 | + assistant_content = response.choices[0].message.content or "" |
| 28 | + messages = list(row.messages) + [Message(role="assistant", content=assistant_content)] |
| 29 | + processed = EvaluationRow( |
| 30 | + messages=messages, |
| 31 | + ground_truth=row.ground_truth, |
| 32 | + input_metadata=InputMetadata(completion_params=CompletionParams(model=config.model)), |
| 33 | + ) |
| 34 | + |
| 35 | + dataset.append(processed) |
| 36 | + return dataset |
0 commit comments