Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions eval_protocol/pytest/default_single_turn_rollout_process.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
import asyncio
from typing import List

from openai import AsyncOpenAI
from litellm import acompletion

from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig


async def default_single_turn_rollout_processor(
rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
"""Generate a single response from a Fireworks model concurrently."""

api_key = get_fireworks_api_key()
api_base = get_fireworks_api_base()
client = AsyncOpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")
"""Generate a single response from any supported model provider using LiteLLM."""

async def process_row(row: EvaluationRow) -> EvaluationRow:
"""Process a single row asynchronously."""
Expand All @@ -24,10 +19,13 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:

messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]

create_kwargs = dict(model=config.model, messages=messages_payload, **config.input_params)
request_params = {"model": config.model, "messages": messages_payload, **config.input_params}

if row.tools is not None:
create_kwargs["tools"] = row.tools
response = await client.chat.completions.create(**create_kwargs)
request_params["tools"] = row.tools

response = await acompletion(**request_params)

assistant_content = response.choices[0].message.content or ""
tool_calls = response.choices[0].message.tool_calls if response.choices[0].message.tool_calls else None
messages = list(row.messages) + [
Expand Down
Loading