Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions eval_protocol/pytest/github_action_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import asyncio
import os
import time
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional
import json
import requests
from datetime import datetime, timezone, timedelta
from eval_protocol.models import EvaluationRow, Status
from eval_protocol.data_loader.dynamic_data_loader import DynamicDataLoader
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig

from .rollout_processor import RolloutProcessor
from .types import RolloutProcessorConfig
Expand All @@ -21,7 +20,7 @@ class GithubActionRolloutProcessor(RolloutProcessor):
Expected GitHub Actions workflow:
- Workflow dispatch with inputs: completion_params, metadata (JSON), model_base_url, api_key
- Workflow makes API calls that get traced (e.g., via Fireworks tracing proxy)
- Traces are fetched later via output_data_loader using rollout_id tags
- Traces are fetched later via Fireworks tracing proxy using rollout_id tags

NOTE: GHA has a rate limit of 5000 requests per hour.
"""
Expand All @@ -38,7 +37,6 @@ def __init__(
timeout_seconds: float = 1800.0,
max_find_workflow_retries: int = 5,
github_token: Optional[str] = None,
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
):
self.owner = owner
self.repo = repo
Expand All @@ -52,7 +50,7 @@ def __init__(
self.timeout_seconds = timeout_seconds
self.max_find_workflow_retries = max_find_workflow_retries
self.github_token = github_token
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
self._output_data_loader = default_fireworks_output_data_loader
Comment thread
xzrderek marked this conversation as resolved.
Outdated

def _headers(self) -> Dict[str, str]:
headers = {"Accept": "application/vnd.github+json"}
Expand Down
8 changes: 3 additions & 5 deletions eval_protocol/pytest/remote_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import time
from typing import Any, Dict, List, Optional, Callable
from typing import Any, Dict, List, Optional

import requests

Expand All @@ -26,8 +26,7 @@ class RemoteRolloutProcessor(RolloutProcessor):
"""
Rollout processor that triggers a remote HTTP server to perform the rollout.

By default, fetches traces from the Fireworks tracing proxy using rollout_id tags.
You can provide a custom output_data_loader for different tracing backends.
Fetches traces from the Fireworks tracing proxy using rollout_id tags.

See https://evalprotocol.io/tutorial/remote-rollout-processor for documentation.
"""
Expand All @@ -39,7 +38,6 @@ def __init__(
model_base_url: str = "https://tracing.fireworks.ai",
poll_interval: float = 1.0,
timeout_seconds: float = 120.0,
output_data_loader: Optional[Callable[[DataLoaderConfig], DynamicDataLoader]] = None,
):
# Prefer constructor-provided configuration. These can be overridden via
# config.kwargs at call time for backward compatibility.
Expand All @@ -52,7 +50,7 @@ def __init__(
self._model_base_url = _ep_model_base_url
self._poll_interval = poll_interval
self._timeout_seconds = timeout_seconds
self._output_data_loader = output_data_loader or default_fireworks_output_data_loader
self._output_data_loader = default_fireworks_output_data_loader
self._tracing_adapter = FireworksTracingAdapter(base_url=self._model_base_url)

def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
Expand Down
19 changes: 0 additions & 19 deletions tests/github_actions/test_github_actions_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
from eval_protocol.models import EvaluationRow, InputMetadata
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.github_action_rollout_processor import GithubActionRolloutProcessor
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation

ROLLOUT_IDS = set()

Expand All @@ -29,21 +26,6 @@ def check_rollout_coverage():
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"


def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
global ROLLOUT_IDS # Track all rollout_ids we've seen
ROLLOUT_IDS.add(config.rollout_id)

base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return DynamicDataLoader(
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
)


def rows() -> List[EvaluationRow]:
return [
EvaluationRow(input_metadata=InputMetadata(row_id=str(i)))
Expand All @@ -68,7 +50,6 @@ def rows() -> List[EvaluationRow]:
ref=os.getenv("GITHUB_REF", "main"),
poll_interval=3.0, # For multi-turn, you'll likely want higher poll interval
timeout_seconds=300,
output_data_loader=fireworks_output_data_loader,
),
)
async def test_github_actions_rollout(row: EvaluationRow) -> EvaluationRow:
Expand Down
24 changes: 1 addition & 23 deletions tests/remote_server/test_remote_fireworks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# AUTO SERVER STARTUP: Server is automatically started and stopped by the test

import os
import subprocess
import socket
import time
Expand All @@ -13,9 +12,6 @@
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig

ROLLOUT_IDS = set()

Expand Down Expand Up @@ -78,21 +74,6 @@ def check_rollout_coverage():
assert len(ROLLOUT_IDS) == 3, f"Expected to see 3 rollout_ids, but only saw {ROLLOUT_IDS}"


def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
global ROLLOUT_IDS # Track all rollout_ids we've seen
ROLLOUT_IDS.add(config.rollout_id)

base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return DynamicDataLoader(
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
)


def rows() -> List[EvaluationRow]:
"""Generate local rows with rich input_metadata to verify it survives remote traces."""
base_dataset_info = {
Expand All @@ -118,7 +99,6 @@ def rows() -> List[EvaluationRow]:
rollout_processor=RemoteRolloutProcessor(
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
timeout_seconds=180,
output_data_loader=fireworks_output_data_loader,
),
)
async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> EvaluationRow:
Expand All @@ -129,13 +109,11 @@ async def test_remote_rollout_and_fetch_fireworks(row: EvaluationRow) -> Evaluat
- fetch traces from Langfuse via Fireworks tracing proxy filtered by metadata via output_data_loader; FAIL if none found
"""
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
ROLLOUT_IDS.add(row.execution_metadata.rollout_id)

assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."

assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
)
assert row.input_metadata.completion_params["model"] == "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"
assert row.input_metadata.completion_params["temperature"] == 0.5, "Row should have temperature at top level"

Expand Down
16 changes: 0 additions & 16 deletions tests/remote_server/test_remote_fireworks_propagate_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
from eval_protocol.models import EvaluationRow, Message, Status, EvaluateResult
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig


def find_available_port() -> int:
Expand Down Expand Up @@ -67,18 +64,6 @@ def setup_remote_server():
process.wait()


def fetch_fireworks_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
base_url = config.model_base_url or "https://tracing.fireworks.ai"
adapter = FireworksTracingAdapter(base_url=base_url)
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=7)


def fireworks_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return DynamicDataLoader(
generators=[lambda: fetch_fireworks_traces(config)], preprocess_fn=filter_longest_conversation
)


def rows() -> List[EvaluationRow]:
row = EvaluationRow(messages=[Message(role="user", content="What is the capital of France?")])
return [row]
Expand All @@ -92,7 +77,6 @@ def rows() -> List[EvaluationRow]:
rollout_processor=RemoteRolloutProcessor(
remote_base_url=f"http://127.0.0.1:{SERVER_PORT}",
timeout_seconds=120,
output_data_loader=fireworks_output_data_loader,
),
)
async def test_remote_rollout_and_fetch_fireworks_propagate_status(row: EvaluationRow) -> EvaluationRow:
Expand Down
55 changes: 1 addition & 54 deletions tests/remote_server/test_remote_langfuse.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,3 @@
# MANUAL SERVER STARTUP REQUIRED:
#
# For Python server testing, start:
# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000)
#
# For TypeScript server testing, start:
# cd tests/remote_server/typescript-server
# npm install
# npm start
#
# The TypeScript server should be running on http://127.0.0.1:3000
# You only need to start one of the servers!

import os
from typing import List

Expand All @@ -20,35 +7,6 @@
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.remote_rollout_processor import RemoteRolloutProcessor
from eval_protocol.adapters.langfuse import create_langfuse_adapter
from eval_protocol.utils.evaluation_row_utils import filter_longest_conversation
from eval_protocol.types.remote_rollout_processor import DataLoaderConfig
Comment thread
xzrderek marked this conversation as resolved.

ROLLOUT_IDS = set()


@pytest.fixture(autouse=True)
def check_rollout_coverage():
"""Ensure we processed all expected rollout_ids"""
global ROLLOUT_IDS
ROLLOUT_IDS.clear()
yield

assert len(ROLLOUT_IDS) == 3, f"Expected to see {ROLLOUT_IDS} rollout_ids, but only saw {ROLLOUT_IDS}"


def fetch_langfuse_traces(config: DataLoaderConfig) -> List[EvaluationRow]:
global ROLLOUT_IDS # Track all rollout_ids we've seen
ROLLOUT_IDS.add(config.rollout_id)

adapter = create_langfuse_adapter()
return adapter.get_evaluation_rows(tags=[f"rollout_id:{config.rollout_id}"], max_retries=5)


def langfuse_output_data_loader(config: DataLoaderConfig) -> DynamicDataLoader:
return DynamicDataLoader(
generators=[lambda: fetch_langfuse_traces(config)], preprocess_fn=filter_longest_conversation
)


def rows() -> List[EvaluationRow]:
Expand All @@ -62,25 +20,14 @@ def rows() -> List[EvaluationRow]:
data_loaders=DynamicDataLoader(
generators=[rows],
),
rollout_processor=RemoteRolloutProcessor(
remote_base_url="http://127.0.0.1:3000",
timeout_seconds=30,
output_data_loader=langfuse_output_data_loader,
model_base_url="https://tracing.fireworks.ai/project_id/cmg5fd57b0006y107kuxkcrhk",
),
rollout_processor=RemoteRolloutProcessor(remote_base_url="http://127.0.0.1:3000", timeout_seconds=30),
)
async def test_remote_rollout_and_fetch_langfuse(row: EvaluationRow) -> EvaluationRow:
"""
End-to-end test:
- REQUIRES MANUAL SERVER STARTUP: python -m tests.remote_server.remote_server
- trigger remote rollout via RemoteRolloutProcessor (calls init/status)
- fetch traces from Langfuse filtered by metadata via output_data_loader; FAIL if none found
"""
assert row.messages[0].content == "What is the capital of France?", "Row should have correct message content"
assert len(row.messages) > 1, "Row should have a response. If this fails, we fellback to the original row."

assert row.execution_metadata.rollout_id in ROLLOUT_IDS, (
f"Row rollout_id {row.execution_metadata.rollout_id} should be in tracked rollout_ids: {ROLLOUT_IDS}"
)

return row
1 change: 0 additions & 1 deletion tests/remote_server/typescript-server/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ from eval_protocol import (
data_loaders=[InlineDataLoader(messages=[[Message(role="user", content="Hello")]])],
rollout_processor=RemoteRolloutProcessor(
remote_base_url="http://localhost:3000",
output_data_loader=create_output_data_loader,
)
)
def test_remote_http(row: EvaluationRow) -> EvaluationRow:
Expand Down
Loading